feat: File search tool for Responses API

This is an initial working prototype that wires the `file_search` builtin
tool of the Responses API up to our existing RAG knowledge search tool.

I stubbed in a new test that uses a hardcoded URL and a hybrid of the OpenAI
and Llama Stack clients for now, only until we finish landing the vector
store APIs and insertion support.

Note that this currently lives under tests/verifications only, because tool
calling with the small Llama-3.2-3B model we run in CI (and that I use as an
example below) sometimes flakes. We'd want to make the test more robust in
some way before moving it over to tests/integration and running it in CI.

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

pytest -sv 'tests/verifications/openai_api/test_responses.py::test_response_non_streaming_file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-3.2-3B-Instruct
```
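
For reference, here is a minimal sketch (not part of this commit) of how a client could exercise the new tool against a running stack. It assumes the `test_vector_store` vector DB has already been registered and populated the way the new verification test does, and that the local server accepts a placeholder API key:

```python
# Hedged sketch: exercise the file_search tool end to end against a local
# Llama Stack server. Assumes `test_vector_store` was registered and populated
# (as in the new verification test) and that no real API key is required.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="none",  # placeholder; assumed to be ignored by the local stack
)

response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "file_search", "vector_store_ids": ["test_vector_store"]}],
)

# output[0] should be the file_search_call record, followed by the assistant message
print(response.output[0].type)
print(response.output_text)
```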

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-06-10 12:42:09 -04:00
parent e2e15ebb6c
commit fa34468308
7 changed files with 234 additions and 11 deletions

View file

```
@@ -7047,6 +7047,9 @@
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
           },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+          },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           },
@@ -7193,7 +7196,7 @@
             "const": "file_search",
             "default": "file_search"
           },
-          "vector_store_id": {
+          "vector_store_ids": {
             "type": "array",
             "items": {
               "type": "string"
@@ -7217,7 +7220,7 @@
         "additionalProperties": false,
         "required": [
           "type",
-          "vector_store_id"
+          "vector_store_ids"
         ],
         "title": "OpenAIResponseInputToolFileSearch"
       },
@@ -7484,6 +7487,64 @@
         ],
         "title": "OpenAIResponseOutputMessageContentOutputText"
       },
+      "OpenAIResponseOutputMessageFileSearchToolCall": {
+        "type": "object",
+        "properties": {
+          "id": {
+            "type": "string"
+          },
+          "queries": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            }
+          },
+          "status": {
+            "type": "string"
+          },
+          "type": {
+            "type": "string",
+            "const": "file_search_call",
+            "default": "file_search_call"
+          },
+          "results": {
+            "type": "array",
+            "items": {
+              "type": "object",
+              "additionalProperties": {
+                "oneOf": [
+                  {
+                    "type": "null"
+                  },
+                  {
+                    "type": "boolean"
+                  },
+                  {
+                    "type": "number"
+                  },
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "type": "array"
+                  },
+                  {
+                    "type": "object"
+                  }
+                ]
+              }
+            }
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "id",
+          "queries",
+          "status",
+          "type"
+        ],
+        "title": "OpenAIResponseOutputMessageFileSearchToolCall"
+      },
       "OpenAIResponseOutputMessageFunctionToolCall": {
         "type": "object",
         "properties": {
@@ -7760,6 +7821,9 @@
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
           },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+          },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           },
@@ -7775,6 +7839,7 @@
         "mapping": {
           "message": "#/components/schemas/OpenAIResponseMessage",
           "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+          "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
           "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
           "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
           "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
```

View file

```
@@ -5021,6 +5021,7 @@ components:
     OpenAIResponseInput:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
         - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
         - $ref: '#/components/schemas/OpenAIResponseMessage'
@@ -5115,7 +5116,7 @@ components:
           type: string
           const: file_search
           default: file_search
-        vector_store_id:
+        vector_store_ids:
           type: array
           items:
             type: string
@@ -5132,7 +5133,7 @@ components:
       additionalProperties: false
       required:
         - type
-        - vector_store_id
+        - vector_store_ids
       title: OpenAIResponseInputToolFileSearch
     OpenAIResponseInputToolFunction:
       type: object
@@ -5294,6 +5295,41 @@ components:
         - type
       title: >-
         OpenAIResponseOutputMessageContentOutputText
+    "OpenAIResponseOutputMessageFileSearchToolCall":
+      type: object
+      properties:
+        id:
+          type: string
+        queries:
+          type: array
+          items:
+            type: string
+        status:
+          type: string
+        type:
+          type: string
+          const: file_search_call
+          default: file_search_call
+        results:
+          type: array
+          items:
+            type: object
+            additionalProperties:
+              oneOf:
+                - type: 'null'
+                - type: boolean
+                - type: number
+                - type: string
+                - type: array
+                - type: object
+      additionalProperties: false
+      required:
+        - id
+        - queries
+        - status
+        - type
+      title: >-
+        OpenAIResponseOutputMessageFileSearchToolCall
     "OpenAIResponseOutputMessageFunctionToolCall":
       type: object
      properties:
@@ -5491,6 +5527,7 @@ components:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseMessage'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
@@ -5499,6 +5536,7 @@
       mapping:
         message: '#/components/schemas/OpenAIResponseMessage'
         web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        file_search_call: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
        function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
        mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
        mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
```

View file

```
@@ -81,6 +81,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"


+@json_schema_type
+class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
+    id: str
+    queries: list[str]
+    status: str
+    type: Literal["file_search_call"] = "file_search_call"
+    results: list[dict[str, Any]] | None = None
+
+
 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
     call_id: str
@@ -119,6 +128,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
     | OpenAIResponseOutputMessageMCPListTools,
@@ -362,6 +372,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
 OpenAIResponseInput = Annotated[
     # Responses API allows output messages to be passed in as input
     OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
     |
@@ -397,9 +408,9 @@ class FileSearchRankingOptions(BaseModel):
 @json_schema_type
 class OpenAIResponseInputToolFileSearch(BaseModel):
     type: Literal["file_search"] = "file_search"
-    vector_store_id: list[str]
+    vector_store_ids: list[str]
     ranking_options: FileSearchRankingOptions | None = None
-    # TODO: add filters
+    # TODO: add filters, max_num_results


 class ApprovalFilter(BaseModel):
```

View file

```
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
 class ChatCompletionContext(BaseModel):
     model: str
     messages: list[OpenAIMessageParam]
-    tools: list[ChatCompletionToolParam] | None = None
+    response_tools: list[OpenAIResponseInputTool] | None = None
+    chat_tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
-            tools=chat_tools,
+            response_tools=tools,
+            chat_tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
         completion_result = await self.inference_api.openai_chat_completion(
             model=ctx.model,
             messages=messages,
-            tools=ctx.tools,
+            tools=ctx.chat_tools,
             stream=True,
             temperature=ctx.temperature,
             response_format=ctx.response_format,
@@ -606,6 +610,12 @@
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
                 chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 always_allowed = None
                 never_allowed = None
@@ -667,6 +677,7 @@
         tool_call_id = tool_call.id
         function = tool_call.function
+        tool_kwargs = json.loads(function.arguments) if function.arguments else {}
         if not function or not tool_call_id or not function.name:
             return None, None

@@ -680,12 +691,18 @@
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
             else:
+                if function.name == "knowledge_search":
+                    response_file_search_tool = next(
+                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    )
+                    if response_file_search_tool:
+                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
                 result = await self.tool_runtime_api.invoke_tool(
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
         except Exception as e:
             error_exc = e
@@ -713,6 +730,27 @@
             )
             if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                 message.status = "failed"
+        elif function.name == "knowledge_search":
+            message = OpenAIResponseOutputMessageFileSearchToolCall(
+                id=tool_call_id,
+                queries=[tool_kwargs.get("query", "")],
+                status="completed",
+            )
+            if "document_ids" in result.metadata:
+                message.results = []
+                for i, doc_id in enumerate(result.metadata["document_ids"]):
+                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
+                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    message.results.append(
+                        {
+                            "file_id": doc_id,
+                            "filename": doc_id,
+                            "text": text,
+                            "score": score,
+                        }
+                    )
+            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")
```

View file

```
@@ -170,6 +170,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "chunks": [c.content for c in chunks[: len(picked)]],
+                "scores": scores[: len(picked)],
             },
         )
```

View file

```
@@ -31,6 +31,18 @@ test_response_web_search:
             search_context_size: "low"
         output: "128"

+test_response_file_search:
+  test_name: test_response_file_search
+  test_params:
+    case:
+      - case_id: "llama_experts"
+        input: "How many experts does the Llama 4 Maverick model have?"
+        tools:
+          - type: file_search
+            vector_store_ids:
+              - test_vector_store
+        output: "128"
+
 test_response_mcp_tool:
   test_name: test_response_mcp_tool
   test_params:
```

View file

```
@@ -9,6 +9,7 @@ import json
 import httpx
 import openai
 import pytest
+from llama_stack_client import LlamaStackClient

 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.distribution.datatypes import AuthenticationRequiredError
@@ -258,6 +259,62 @@
     assert case["output"].lower() in response.output_text.lower().strip()
+
+
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_file_search"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_file_search(
+    base_url, request, openai_client, model, provider, verification_config, case
+):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    lls_client = LlamaStackClient(base_url=base_url.replace("/v1/openai/v1", ""))
+    vector_db_id = "test_vector_store"
+    # Ensure the test starts from a clean vector store
+    try:
+        lls_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+    except Exception:
+        pass
+    lls_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+    )
+
+    doc_content = "Llama 4 Maverick has 128 experts"
+    chunks = [
+        {
+            "content": doc_content,
+            "mime_type": "text/plain",
+            "metadata": {
+                "document_id": "doc1",
+            },
+        },
+    ]
+    lls_client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries  # ensure it's some non-empty list
+    assert response.output[0].results[0].text == doc_content
+    assert response.output[0].results[0].score > 0
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert case["output"].lower() in response.output_text.lower().strip()


 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
```