Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: support filters in file search (#2472)
# What does this PR do?

Move to use vector_stores.search for the file search tool in Responses, which supports filters.

Closes #2435.

## Test Plan

Added an e2e test with filters.

```
myenv ❯ llama stack run llama_stack/templates/fireworks/run.yaml

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search and filters' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
```
This commit is contained in:
parent fd37a50e6a
commit db2cd9e8f3

13 changed files with 449 additions and 63 deletions
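For orientation, here is a minimal sketch of how the new filter support is exercised from a client, mirroring the e2e tests added in this change. The client setup, API key, and vector store ID below are assumptions for illustration, not part of the diff.

```python
# Minimal sketch: assumes a Llama Stack server exposing the OpenAI-compatible
# endpoint used in the test plan above. The API key value and vector store ID
# are hypothetical placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What are the updates from the US region?",
    tools=[
        {
            "type": "file_search",
            "vector_store_ids": ["vs_123"],  # hypothetical vector store ID
            # New in this change: attribute filters are passed through to
            # vector_stores.search instead of being ignored.
            "filters": {"type": "eq", "key": "region", "value": "us"},
        }
    ],
    include=["file_search_call.results"],
)
print(response.output_text)
```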
docs/_static/llama-stack-spec.html (vendored, 26 changes)

```diff
@@ -7296,7 +7296,7 @@
             }
           },
           "additionalProperties": false,
-          "title": "FileSearchRankingOptions"
+          "title": "SearchRankingOptions"
         }
       },
       "additionalProperties": false,
@@ -13478,28 +13478,16 @@
           },
           "ranking_options": {
             "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "array"
-                },
-                {
-                  "type": "object"
-                }
-              ]
-            },
+            "properties": {
+              "ranker": {
+                "type": "string"
+              },
+              "score_threshold": {
+                "type": "number",
+                "default": 0.0
+              }
+            },
+            "additionalProperties": false,
             "description": "Ranking options for fine-tuning the search results."
           },
           "rewrite_query": {
```
docs/_static/llama-stack-spec.yaml (vendored, 17 changes)

```diff
@@ -5179,7 +5179,7 @@ components:
           type: number
           default: 0.0
         additionalProperties: false
-        title: FileSearchRankingOptions
+        title: SearchRankingOptions
       additionalProperties: false
       required:
         - type
@@ -9408,14 +9408,13 @@ components:
             Maximum number of results to return (1 to 50 inclusive, default 10).
         ranking_options:
           type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
+          properties:
+            ranker:
+              type: string
+            score_threshold:
+              type: number
+              default: 0.0
+          additionalProperties: false
           description: >-
             Ranking options for fine-tuning the search results.
         rewrite_query:
```
```diff
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict

+from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
 from llama_stack.schema_utils import json_schema_type, register_schema

 # NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
@@ -400,11 +401,6 @@ class OpenAIResponseInputToolFunction(BaseModel):
     strict: bool | None = None


-class FileSearchRankingOptions(BaseModel):
-    ranker: str | None = None
-    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
-
-
 @json_schema_type
 class OpenAIResponseInputToolFileSearch(BaseModel):
     type: Literal["file_search"] = "file_search"
```
```diff
@@ -157,6 +157,11 @@ VectorStoreChunkingStrategy = Annotated[
 register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")


+class SearchRankingOptions(BaseModel):
+    ranker: str | None = None
+    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
+
+
 @json_schema_type
 class VectorStoreFileLastError(BaseModel):
     code: Literal["server_error"] | Literal["rate_limit_exceeded"]
@@ -319,7 +324,7 @@ class VectorIO(Protocol):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         """Search for chunks in a vector store.
```
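The signature change above means callers now pass a typed `SearchRankingOptions` rather than a raw dict. A minimal sketch of the updated call, assuming an existing `VectorIO` implementation and a hypothetical vector store ID:

```python
# Sketch only: `vector_io_api` and the vector store ID are assumed to exist.
from llama_stack.apis.vector_io import SearchRankingOptions, VectorIO


async def search_us_marketing(vector_io_api: VectorIO, vector_store_id: str):
    # ranking_options is now a typed model instead of dict[str, Any]
    return await vector_io_api.openai_search_vector_store(
        vector_store_id=vector_store_id,
        query="Q1 marketing results",
        filters={"type": "eq", "key": "region", "value": "us"},
        max_num_results=10,
        ranking_options=SearchRankingOptions(score_threshold=0.5),
        rewrite_query=False,
    )
```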
```diff
@@ -14,6 +14,7 @@ from llama_stack.apis.models import ModelType
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
+    SearchRankingOptions,
     VectorIO,
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
@@ -246,7 +247,7 @@ class VectorIORouter(VectorIO):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
```
```diff
@@ -85,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
             tool_groups_api=self.tool_groups_api,
             tool_runtime_api=self.tool_runtime_api,
             responses_store=self.responses_store,
+            vector_io_api=self.vector_io_api,
         )

     async def create_agent(
```
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import json
 import time
 import uuid
@@ -42,6 +43,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
     OpenAIResponseTextFormat,
 )
+from llama_stack.apis.common.content_types import TextContentItem
 from llama_stack.apis.inference.inference import (
     Inference,
     OpenAIAssistantMessageParam,
@@ -64,7 +66,8 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
@@ -214,11 +217,13 @@ class OpenAIResponsesImpl:
         tool_groups_api: ToolGroups,
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
+        vector_io_api: VectorIO,  # VectorIO
     ):
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
         self.responses_store = responses_store
+        self.vector_io_api = vector_io_api

     async def _prepend_previous_response(
         self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
@@ -666,6 +671,71 @@ class OpenAIResponsesImpl:
             raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
         return chat_tools, mcp_tool_to_server, mcp_list_message

+    async def _execute_knowledge_search_via_vector_store(
+        self,
+        query: str,
+        response_file_search_tool: OpenAIResponseInputToolFileSearch,
+    ) -> ToolInvocationResult:
+        """Execute knowledge search using vector_stores.search API with filters support."""
+        search_results = []
+
+        # Create search tasks for all vector stores
+        async def search_single_store(vector_store_id):
+            try:
+                search_response = await self.vector_io_api.openai_search_vector_store(
+                    vector_store_id=vector_store_id,
+                    query=query,
+                    filters=response_file_search_tool.filters,
+                    max_num_results=response_file_search_tool.max_num_results,
+                    ranking_options=response_file_search_tool.ranking_options,
+                    rewrite_query=False,
+                )
+                return search_response.data
+            except Exception as e:
+                logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
+                return []
+
+        # Run all searches in parallel using gather
+        search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
+        all_results = await asyncio.gather(*search_tasks)
+
+        # Flatten results
+        for results in all_results:
+            search_results.extend(results)
+
+        # Convert search results to tool result format matching memory.py
+        # Format the results as interleaved content similar to memory.py
+        content_items = []
+        content_items.append(
+            TextContentItem(
+                text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
+            )
+        )
+
+        for i, result_item in enumerate(search_results):
+            chunk_text = result_item.content[0].text if result_item.content else ""
+            metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+            if result_item.attributes:
+                metadata_text += f", attributes: {result_item.attributes}"
+            text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+            content_items.append(TextContentItem(text=text_content))
+
+        content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+        content_items.append(
+            TextContentItem(
+                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+            )
+        )
+
+        return ToolInvocationResult(
+            content=content_items,
+            metadata={
+                "document_ids": [r.file_id for r in search_results],
+                "chunks": [r.content[0].text if r.content else "" for r in search_results],
+                "scores": [r.score for r in search_results],
+            },
+        )
+
     async def _execute_tool_call(
         self,
         tool_call: OpenAIChatCompletionToolCall,
@@ -693,21 +763,19 @@ class OpenAIResponsesImpl:
                 tool_name=function.name,
                 kwargs=tool_kwargs,
             )
-        else:
-            if function.name == "knowledge_search":
-                response_file_search_tool = next(
-                    t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
-                )
-                if response_file_search_tool:
-                    if response_file_search_tool.filters:
-                        logger.warning("Filters are not yet supported for file_search tool")
-                    if response_file_search_tool.ranking_options:
-                        logger.warning("Ranking options are not yet supported for file_search tool")
-                    tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
-                    tool_kwargs["query_config"] = RAGQueryConfig(
-                        mode="vector",
-                        max_chunks=response_file_search_tool.max_num_results,
-                    )
+        elif function.name == "knowledge_search":
+            response_file_search_tool = next(
+                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+            )
+            if response_file_search_tool:
+                # Use vector_stores.search API instead of knowledge_search tool
+                # to support filters and ranking_options
+                query = tool_kwargs.get("query", "")
+                result = await self._execute_knowledge_search_via_vector_store(
+                    query=query,
+                    response_file_search_tool=response_file_search_tool,
+                )
+        else:
             result = await self.tool_runtime_api.invoke_tool(
                 tool_name=function.name,
                 kwargs=tool_kwargs,
```
```diff
@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
+    SearchRankingOptions,
     VectorIO,
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
@@ -249,7 +250,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
```
```diff
@@ -19,6 +19,7 @@ from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
+    SearchRankingOptions,
     VectorIO,
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
@@ -247,7 +248,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
```
```diff
@@ -17,6 +17,7 @@ from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
+    SearchRankingOptions,
     VectorIO,
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
@@ -249,7 +250,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
```
```diff
@@ -15,6 +15,7 @@ from llama_stack.apis.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     QueryChunksResponse,
+    SearchRankingOptions,
     VectorStoreContent,
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
@@ -296,7 +297,7 @@ class OpenAIVectorStoreMixin(ABC):
         query: str | list[str],
         filters: dict[str, Any] | None = None,
         max_num_results: int | None = 10,
-        ranking_options: dict[str, Any] | None = None,
+        ranking_options: SearchRankingOptions | None = None,
         rewrite_query: bool | None = False,
         # search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
     ) -> VectorStoreSearchResponsePage:
@@ -314,7 +315,11 @@ class OpenAIVectorStoreMixin(ABC):
         search_query = query

         try:
-            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
+            score_threshold = (
+                ranking_options.score_threshold
+                if ranking_options and ranking_options.score_threshold is not None
+                else 0.0
+            )
             params = {
                 "max_chunks": max_num_results * CHUNK_MULTIPLIER,
                 "score_threshold": score_threshold,
@@ -399,12 +404,49 @@ class OpenAIVectorStoreMixin(ABC):

     def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
         """Check if metadata matches the provided filters."""
-        for key, value in filters.items():
-            if key not in metadata:
-                return False
-            if metadata[key] != value:
-                return False
-        return True
+        if not filters:
+            return True
+
+        filter_type = filters.get("type")
+
+        if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
+            # Comparison filter
+            key = filters.get("key")
+            value = filters.get("value")
+
+            if key not in metadata:
+                return False
+
+            metadata_value = metadata[key]
+
+            if filter_type == "eq":
+                return bool(metadata_value == value)
+            elif filter_type == "ne":
+                return bool(metadata_value != value)
+            elif filter_type == "gt":
+                return bool(metadata_value > value)
+            elif filter_type == "gte":
+                return bool(metadata_value >= value)
+            elif filter_type == "lt":
+                return bool(metadata_value < value)
+            elif filter_type == "lte":
+                return bool(metadata_value <= value)
+            else:
+                raise ValueError(f"Unsupported filter type: {filter_type}")
+
+        elif filter_type == "and":
+            # All filters must match
+            sub_filters = filters.get("filters", [])
+            return all(self._matches_filters(metadata, f) for f in sub_filters)
+
+        elif filter_type == "or":
+            # At least one filter must match
+            sub_filters = filters.get("filters", [])
+            return any(self._matches_filters(metadata, f) for f in sub_filters)
+
+        else:
+            # Unknown filter type, default to no match
+            raise ValueError(f"Unsupported filter type: {filter_type}")

     async def openai_attach_file_to_vector_store(
         self,
```
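As a quick reference, these are the filter shapes that `_matches_filters` accepts, shown as plain data with illustrative values taken from the tests in this change:

```python
# Comparison filter: "type" is one of eq, ne, gt, gte, lt, lte.
comparison_filter = {"type": "eq", "key": "region", "value": "us"}

# Compound filters: "and" requires every sub-filter to match, "or" requires any.
date_range_filter = {
    "type": "and",
    "filters": [
        {"type": "gte", "key": "date", "value": 1672531200},  # Jan 1, 2023
        {"type": "lt", "key": "date", "value": 1680307200},   # Apr 1, 2023
    ],
}
```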
```diff
@@ -71,12 +71,21 @@ def mock_responses_store():


 @pytest.fixture
-def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store):
+def mock_vector_io_api():
+    vector_io_api = AsyncMock()
+    return vector_io_api
+
+
+@pytest.fixture
+def openai_responses_impl(
+    mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
+):
     return OpenAIResponsesImpl(
         inference_api=mock_inference_api,
         tool_groups_api=mock_tool_groups_api,
         tool_runtime_api=mock_tool_runtime_api,
         responses_store=mock_responses_store,
+        vector_io_api=mock_vector_io_api,
     )

```
```diff
@@ -714,3 +714,277 @@ def test_response_text_format(request, openai_client, model, provider, verification_config, text_format):
     assert "paris" in response.output_text.lower()
     if text_format["type"] == "json_schema":
         assert "paris" in json.loads(response.output_text)["capital"].lower()
+
+
+@pytest.fixture
+def vector_store_with_filtered_files(request, openai_client, model, provider, verification_config, tmp_path_factory):
+    """Create a vector store with multiple files that have different attributes for filtering tests."""
+    if isinstance(openai_client, LlamaStackAsLibraryClient):
+        pytest.skip("Responses API file search is not yet supported in library client.")
+
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    vector_store = _new_vector_store(openai_client, "test_vector_store_with_filters")
+    tmp_path = tmp_path_factory.mktemp("filter_test_files")
+
+    # Create multiple files with different attributes
+    files_data = [
+        {
+            "name": "us_marketing_q1.txt",
+            "content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
+            "attributes": {
+                "region": "us",
+                "category": "marketing",
+                "date": 1672531200,  # Jan 1, 2023
+            },
+        },
+        {
+            "name": "us_engineering_q2.txt",
+            "content": "US technical updates for Q2 2023. New features deployed in the US region.",
+            "attributes": {
+                "region": "us",
+                "category": "engineering",
+                "date": 1680307200,  # Apr 1, 2023
+            },
+        },
+        {
+            "name": "eu_marketing_q1.txt",
+            "content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
+            "attributes": {
+                "region": "eu",
+                "category": "marketing",
+                "date": 1672531200,  # Jan 1, 2023
+            },
+        },
+        {
+            "name": "asia_sales_q3.txt",
+            "content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
+            "attributes": {
+                "region": "asia",
+                "category": "sales",
+                "date": 1688169600,  # Jul 1, 2023
+            },
+        },
+    ]
+
+    file_ids = []
+    for file_data in files_data:
+        # Create file
+        file_path = tmp_path / file_data["name"]
+        file_path.write_text(file_data["content"])
+
+        # Upload file
+        file_response = _upload_file(openai_client, file_data["name"], str(file_path))
+        file_ids.append(file_response.id)
+
+        # Attach file to vector store with attributes
+        file_attach_response = openai_client.vector_stores.files.create(
+            vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
+        )
+
+        # Wait for attachment
+        while file_attach_response.status == "in_progress":
+            time.sleep(0.1)
+            file_attach_response = openai_client.vector_stores.files.retrieve(
+                vector_store_id=vector_store.id,
+                file_id=file_response.id,
+            )
+        assert file_attach_response.status == "completed"
+
+    yield vector_store
+
+    # Cleanup: delete vector store and files
+    try:
+        openai_client.vector_stores.delete(vector_store_id=vector_store.id)
+        for file_id in file_ids:
+            try:
+                openai_client.files.delete(file_id=file_id)
+            except Exception:
+                pass  # File might already be deleted
+    except Exception:
+        pass  # Best effort cleanup
+
+
+def test_response_file_search_filter_by_region(openai_client, model, vector_store_with_filtered_files):
+    """Test file search with region equality filter."""
+    tools = [
+        {
+            "type": "file_search",
+            "vector_store_ids": [vector_store_with_filtered_files.id],
+            "filters": {"type": "eq", "key": "region", "value": "us"},
+        }
+    ]
+
+    response = openai_client.responses.create(
+        model=model,
+        input="What are the updates from the US region?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    # Verify file search was called with US filter
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].results
+    # Should only return US files (not EU or Asia files)
+    for result in response.output[0].results:
+        assert "us" in result.text.lower() or "US" in result.text
+        # Ensure non-US regions are NOT returned
+        assert "european" not in result.text.lower()
+        assert "asia" not in result.text.lower()
+
+
+def test_response_file_search_filter_by_category(openai_client, model, vector_store_with_filtered_files):
+    """Test file search with category equality filter."""
+    tools = [
+        {
+            "type": "file_search",
+            "vector_store_ids": [vector_store_with_filtered_files.id],
+            "filters": {"type": "eq", "key": "category", "value": "marketing"},
+        }
+    ]
+
+    response = openai_client.responses.create(
+        model=model,
+        input="Show me all marketing reports",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].results
+    # Should only return marketing files (not engineering or sales)
+    for result in response.output[0].results:
+        # Marketing files should have promotional/advertising content
+        assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
+        # Ensure non-marketing categories are NOT returned
+        assert "technical" not in result.text.lower()
+        assert "revenue figures" not in result.text.lower()
+
+
+def test_response_file_search_filter_by_date_range(openai_client, model, vector_store_with_filtered_files):
+    """Test file search with date range filter using compound AND."""
+    tools = [
+        {
+            "type": "file_search",
+            "vector_store_ids": [vector_store_with_filtered_files.id],
+            "filters": {
+                "type": "and",
+                "filters": [
+                    {
+                        "type": "gte",
+                        "key": "date",
+                        "value": 1672531200,  # Jan 1, 2023
+                    },
+                    {
+                        "type": "lt",
+                        "key": "date",
+                        "value": 1680307200,  # Apr 1, 2023
+                    },
+                ],
+            },
+        }
+    ]
+
+    response = openai_client.responses.create(
+        model=model,
+        input="What happened in Q1 2023?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].results
+    # Should only return Q1 files (not Q2 or Q3)
+    for result in response.output[0].results:
+        assert "q1" in result.text.lower()
+        # Ensure non-Q1 quarters are NOT returned
+        assert "q2" not in result.text.lower()
+        assert "q3" not in result.text.lower()
+
+
+def test_response_file_search_filter_compound_and(openai_client, model, vector_store_with_filtered_files):
+    """Test file search with compound AND filter (region AND category)."""
+    tools = [
+        {
+            "type": "file_search",
+            "vector_store_ids": [vector_store_with_filtered_files.id],
+            "filters": {
+                "type": "and",
+                "filters": [
+                    {"type": "eq", "key": "region", "value": "us"},
+                    {"type": "eq", "key": "category", "value": "engineering"},
+                ],
+            },
+        }
+    ]
+
+    response = openai_client.responses.create(
+        model=model,
+        input="What are the engineering updates from the US?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].results
+    # Should only return US engineering files
+    assert len(response.output[0].results) >= 1
+    for result in response.output[0].results:
+        assert "us" in result.text.lower() and "technical" in result.text.lower()
+        # Ensure it's not from other regions or categories
+        assert "european" not in result.text.lower() and "asia" not in result.text.lower()
+        assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
+
+
+def test_response_file_search_filter_compound_or(openai_client, model, vector_store_with_filtered_files):
+    """Test file search with compound OR filter (marketing OR sales)."""
+    tools = [
+        {
+            "type": "file_search",
+            "vector_store_ids": [vector_store_with_filtered_files.id],
+            "filters": {
+                "type": "or",
+                "filters": [
+                    {"type": "eq", "key": "category", "value": "marketing"},
+                    {"type": "eq", "key": "category", "value": "sales"},
+                ],
+            },
+        }
+    ]
+
+    response = openai_client.responses.create(
+        model=model,
+        input="Show me marketing and sales documents",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].results
+    # Should return marketing and sales files, but NOT engineering
+    categories_found = set()
+    for result in response.output[0].results:
+        text_lower = result.text.lower()
+        if "promotional" in text_lower or "advertising" in text_lower:
+            categories_found.add("marketing")
+        if "revenue figures" in text_lower:
+            categories_found.add("sales")
+        # Ensure engineering files are NOT returned
+        assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
+
+    # Verify we got at least one of the expected categories
+    assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
+    assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"
```