From fa344683080c54ee3529e3f7c7b4ff39c7d47b89 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Tue, 10 Jun 2025 12:42:09 -0400
Subject: [PATCH] feat: File search tool for Responses API

This is an initial working prototype of wiring up the `file_search`
builtin tool for the Responses API to our existing RAG knowledge search
tool. I stubbed in a new test (that uses a hardcoded URL and a hybrid of
the OpenAI and Llama Stack clients for now, only until we finish landing
the vector store APIs and insertion support).

Note that this currently lives under tests/verifications only because it
sometimes flakes on tool calling with the small Llama-3.2-3B model we run
in CI (and that I use as an example below). We'd want to make the test a
bit more robust in some way before moving it over to tests/integration
and running it in CI.

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

pytest -sv 'tests/verifications/openai_api/test_responses.py::test_response_non_streaming_file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model meta-llama/Llama-3.2-3B-Instruct
```
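
For reference, here is a minimal sketch of what exercising the new builtin
tool looks like from a plain OpenAI client pointed at a running stack. This
mirrors the new verification test below; the base URL, API key, and vector
store id are illustrative only, and the vector store is assumed to already
contain documents:

```
# Sketch only: assumes a Llama Stack server on localhost:8321 and a vector
# store named "test_vector_store" that already has documents inserted.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "file_search", "vector_store_ids": ["test_vector_store"]}],
)
# output[0] is the file_search_call, output[1] is the assistant message
print(response.output[0].type, response.output_text)
```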
"$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall" + }, { "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" }, @@ -7775,6 +7839,7 @@ "mapping": { "message": "#/components/schemas/OpenAIResponseMessage", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", + "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall", "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall", "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall", "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index b2fe870be..4154a430d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -5021,6 +5021,7 @@ components: OpenAIResponseInput: oneOf: - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput' - $ref: '#/components/schemas/OpenAIResponseMessage' @@ -5115,7 +5116,7 @@ components: type: string const: file_search default: file_search - vector_store_id: + vector_store_ids: type: array items: type: string @@ -5132,7 +5133,7 @@ components: additionalProperties: false required: - type - - vector_store_id + - vector_store_ids title: OpenAIResponseInputToolFileSearch OpenAIResponseInputToolFunction: type: object @@ -5294,6 +5295,41 @@ components: - type title: >- OpenAIResponseOutputMessageContentOutputText + "OpenAIResponseOutputMessageFileSearchToolCall": + type: object + properties: + id: + type: string + queries: + type: array + items: + type: string + status: + type: string + type: + type: string + const: file_search_call + default: file_search_call + results: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - id + - queries + - status + - type + title: >- + OpenAIResponseOutputMessageFileSearchToolCall "OpenAIResponseOutputMessageFunctionToolCall": type: object properties: @@ -5491,6 +5527,7 @@ components: oneOf: - $ref: '#/components/schemas/OpenAIResponseMessage' - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' @@ -5499,6 +5536,7 @@ components: mapping: message: '#/components/schemas/OpenAIResponseMessage' web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' + file_search_call: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall' function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 35b3d5ace..bdd9c3e26 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ 
@@ -81,6 +81,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
+    id: str
+    queries: list[str]
+    status: str
+    type: Literal["file_search_call"] = "file_search_call"
+    results: list[dict[str, Any]] | None = None
+
+
 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
     call_id: str
@@ -119,6 +128,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
     | OpenAIResponseOutputMessageMCPListTools,
@@ -362,6 +372,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
 OpenAIResponseInput = Annotated[
     # Responses API allows output messages to be passed in as input
     OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
     |
@@ -397,9 +408,9 @@ class FileSearchRankingOptions(BaseModel):
 @json_schema_type
 class OpenAIResponseInputToolFileSearch(BaseModel):
     type: Literal["file_search"] = "file_search"
-    vector_store_id: list[str]
+    vector_store_ids: list[str]
     ranking_options: FileSearchRankingOptions | None = None
-    # TODO: add filters
+    # TODO: add filters, max_num_results
 
 
 class ApprovalFilter(BaseModel):
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 0ff6dc2c5..963dd1ddd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
 class ChatCompletionContext(BaseModel):
     model: str
     messages: list[OpenAIMessageParam]
-    tools: list[ChatCompletionToolParam] | None = None
+    response_tools: list[OpenAIResponseInputTool] | None = None
+    chat_tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
-            tools=chat_tools,
+            response_tools=tools,
+            chat_tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
         completion_result = await self.inference_api.openai_chat_completion(
             model=ctx.model,
             messages=messages,
-            tools=ctx.tools,
+            tools=ctx.chat_tools,
             stream=True,
             temperature=ctx.temperature,
             response_format=ctx.response_format,
@@ -606,6 +610,12 @@
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
                 chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 always_allowed = None
                 never_allowed = None
@@ -667,6 +677,7 @@
 
         tool_call_id = tool_call.id
         function = tool_call.function
+        tool_kwargs = json.loads(function.arguments) if function.arguments else {}
 
         if not function or not tool_call_id or not function.name:
             return None, None
@@ -680,12 +691,18 @@
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
             else:
+                if function.name == "knowledge_search":
+                    response_file_search_tool = next(
+                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    )
+                    if response_file_search_tool:
+                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
                 result = await self.tool_runtime_api.invoke_tool(
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
         except Exception as e:
             error_exc = e
@@ -713,6 +730,27 @@
             )
             if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                 message.status = "failed"
+        elif function.name == "knowledge_search":
+            message = OpenAIResponseOutputMessageFileSearchToolCall(
+                id=tool_call_id,
+                queries=[tool_kwargs.get("query", "")],
+                status="completed",
+            )
+            if "document_ids" in result.metadata:
+                message.results = []
+                for i, doc_id in enumerate(result.metadata["document_ids"]):
+                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
+                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    message.results.append(
+                        {
+                            "file_id": doc_id,
+                            "filename": doc_id,
+                            "text": text,
+                            "score": score,
+                        }
+                    )
+            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")
 
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 4776d47d0..e15d067a7 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -170,6 +170,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "chunks": [c.content for c in chunks[: len(picked)]],
+                "scores": scores[: len(picked)],
             },
         )
 
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index 4d6c19b59..7115e4b50 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -31,6 +31,18 @@ test_response_web_search:
           search_context_size: "low"
         output: "128"
 
+test_response_file_search:
+  test_name: test_response_file_search
+  test_params:
+    case:
+      - case_id: "llama_experts"
+        input: "How many experts does the Llama 4 Maverick model have?"
+        tools:
+          - type: file_search
+            vector_store_ids:
+              - test_vector_store
+        output: "128"
+
 test_response_mcp_tool:
   test_name: test_response_mcp_tool
   test_params:
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index 28020d3b1..86b267fac 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -9,6 +9,7 @@ import json
 import httpx
 import openai
 import pytest
+from llama_stack_client import LlamaStackClient
 
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.distribution.datatypes import AuthenticationRequiredError
@@ -258,6 +259,62 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_file_search"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_file_search(
+    base_url, request, openai_client, model, provider, verification_config, case
+):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    lls_client = LlamaStackClient(base_url=base_url.replace("/v1/openai/v1", ""))
+    vector_db_id = "test_vector_store"
+
+    # Ensure the test starts from a clean vector store
+    try:
+        lls_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+    except Exception:
+        pass
+
+    lls_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+    )
+    doc_content = "Llama 4 Maverick has 128 experts"
+    chunks = [
+        {
+            "content": doc_content,
+            "mime_type": "text/plain",
+            "metadata": {
+                "document_id": "doc1",
+            },
+        },
+    ]
+    lls_client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries  # ensure it's some non-empty list
+    assert response.output[0].results[0].text == doc_content
+    assert response.output[0].results[0].score > 0
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert case["output"].lower() in response.output_text.lower().strip()
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],