diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 96de04ec9..aef066f11 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -7047,6 +7047,9 @@
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
           },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+          },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           },
@@ -7193,7 +7196,7 @@
             "const": "file_search",
             "default": "file_search"
           },
-          "vector_store_id": {
+          "vector_store_ids": {
             "type": "array",
             "items": {
               "type": "string"
             }
           },
@@ -7217,7 +7220,7 @@
         "additionalProperties": false,
         "required": [
           "type",
-          "vector_store_id"
+          "vector_store_ids"
         ],
         "title": "OpenAIResponseInputToolFileSearch"
       },
@@ -7484,6 +7487,64 @@
         ],
         "title": "OpenAIResponseOutputMessageContentOutputText"
       },
+      "OpenAIResponseOutputMessageFileSearchToolCall": {
+        "type": "object",
+        "properties": {
+          "id": {
+            "type": "string"
+          },
+          "queries": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            }
+          },
+          "status": {
+            "type": "string"
+          },
+          "type": {
+            "type": "string",
+            "const": "file_search_call",
+            "default": "file_search_call"
+          },
+          "results": {
+            "type": "array",
+            "items": {
+              "type": "object",
+              "additionalProperties": {
+                "oneOf": [
+                  {
+                    "type": "null"
+                  },
+                  {
+                    "type": "boolean"
+                  },
+                  {
+                    "type": "number"
+                  },
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "type": "array"
+                  },
+                  {
+                    "type": "object"
+                  }
+                ]
+              }
+            }
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "id",
+          "queries",
+          "status",
+          "type"
+        ],
+        "title": "OpenAIResponseOutputMessageFileSearchToolCall"
+      },
       "OpenAIResponseOutputMessageFunctionToolCall": {
         "type": "object",
         "properties": {
@@ -7760,6 +7821,9 @@
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
           },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall"
+          },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           },
@@ -7775,6 +7839,7 @@
         "mapping": {
           "message": "#/components/schemas/OpenAIResponseMessage",
           "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+          "file_search_call": "#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall",
          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
          "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
          "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
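
For reference, an output item conforming to the new OpenAIResponseOutputMessageFileSearchToolCall schema could look like the sketch below. The id, score, and text are illustrative values only; the result keys (file_id, filename, text, score) match what the inline responses implementation emits later in this patch.

# Illustrative only: a file_search_call output item as it would appear in a
# /v1/openai/v1/responses response body (values are made up).
file_search_call_item = {
    "type": "file_search_call",
    "id": "fs_call_123",  # hypothetical identifier
    "queries": ["llama 4 maverick experts"],
    "status": "completed",
    "results": [
        {
            "file_id": "doc1",
            "filename": "doc1",
            "text": "Llama 4 Maverick has 128 experts",
            "score": 0.87,
        },
    ],
}
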
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index b2fe870be..4154a430d 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5021,6 +5021,7 @@ components:
   OpenAIResponseInput:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
       - $ref: '#/components/schemas/OpenAIResponseMessage'
@@ -5115,7 +5116,7 @@ components:
         type: string
         const: file_search
         default: file_search
-      vector_store_id:
+      vector_store_ids:
        type: array
        items:
          type: string
@@ -5132,7 +5133,7 @@ components:
     additionalProperties: false
     required:
       - type
-      - vector_store_id
+      - vector_store_ids
     title: OpenAIResponseInputToolFileSearch
   OpenAIResponseInputToolFunction:
     type: object
     properties:
@@ -5294,6 +5295,41 @@ components:
       - type
     title: >-
       OpenAIResponseOutputMessageContentOutputText
+  "OpenAIResponseOutputMessageFileSearchToolCall":
+    type: object
+    properties:
+      id:
+        type: string
+      queries:
+        type: array
+        items:
+          type: string
+      status:
+        type: string
+      type:
+        type: string
+        const: file_search_call
+        default: file_search_call
+      results:
+        type: array
+        items:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+    additionalProperties: false
+    required:
+      - id
+      - queries
+      - status
+      - type
+    title: >-
+      OpenAIResponseOutputMessageFileSearchToolCall
   "OpenAIResponseOutputMessageFunctionToolCall":
     type: object
     properties:
@@ -5491,6 +5527,7 @@ components:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseMessage'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
@@ -5499,6 +5536,7 @@ components:
       mapping:
         message: '#/components/schemas/OpenAIResponseMessage'
         web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        file_search_call: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
        function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
        mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
        mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index 35b3d5ace..bdd9c3e26 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -81,6 +81,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
+    id: str
+    queries: list[str]
+    status: str
+    type: Literal["file_search_call"] = "file_search_call"
+    results: list[dict[str, Any]] | None = None
+
+
 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
     call_id: str
@@ -119,6 +128,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
 OpenAIResponseOutput = Annotated[
     OpenAIResponseMessage
     | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseOutputMessageMCPCall
     | OpenAIResponseOutputMessageMCPListTools,
@@ -362,6 +372,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
 OpenAIResponseInput = Annotated[
     # Responses API allows output messages to be passed in as input
     OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFileSearchToolCall
     | OpenAIResponseOutputMessageFunctionToolCall
     | OpenAIResponseInputFunctionToolCallOutput
     |
@@ -397,9 +408,9 @@ class FileSearchRankingOptions(BaseModel):
 @json_schema_type
 class OpenAIResponseInputToolFileSearch(BaseModel):
     type: Literal["file_search"] = "file_search"
-    vector_store_id: list[str]
+    vector_store_ids: list[str]
     ranking_options: FileSearchRankingOptions | None = None
-    # TODO: add filters
+    # TODO: add filters, max_num_results
 
 
 class ApprovalFilter(BaseModel):
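
A minimal sketch of how the new and renamed models above can be exercised directly; the class names, fields, and module path come from this patch, while the id and result values are made up.

# Sketch (not part of the patch): constructing the new Pydantic models.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputToolFileSearch,
    OpenAIResponseOutputMessageFileSearchToolCall,
)

# The tool config now takes a list of vector store IDs ("vector_store_ids", plural).
tool = OpenAIResponseInputToolFileSearch(vector_store_ids=["test_vector_store"])

# The new output item type mirrors OpenAI's file_search_call objects.
call = OpenAIResponseOutputMessageFileSearchToolCall(
    id="fs_call_123",  # hypothetical id
    queries=["How many experts does the Llama 4 Maverick model have?"],
    status="completed",
    results=[{"file_id": "doc1", "filename": "doc1", "text": "Llama 4 Maverick has 128 experts", "score": 0.9}],
)
print(tool.model_dump_json())
print(call.model_dump_json())
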
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 0ff6dc2c5..963dd1ddd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
 class ChatCompletionContext(BaseModel):
     model: str
     messages: list[OpenAIMessageParam]
-    tools: list[ChatCompletionToolParam] | None = None
+    response_tools: list[OpenAIResponseInputTool] | None = None
+    chat_tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
-            tools=chat_tools,
+            response_tools=tools,
+            chat_tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
             completion_result = await self.inference_api.openai_chat_completion(
                 model=ctx.model,
                 messages=messages,
-                tools=ctx.tools,
+                tools=ctx.chat_tools,
                 stream=True,
                 temperature=ctx.temperature,
                 response_format=ctx.response_format,
@@ -606,6 +610,12 @@
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
                 chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 always_allowed = None
                 never_allowed = None
@@ -667,6 +677,7 @@
         tool_call_id = tool_call.id
         function = tool_call.function
+        tool_kwargs = json.loads(function.arguments) if function.arguments else {}
 
         if not function or not tool_call_id or not function.name:
             return None, None
@@ -680,12 +691,18 @@
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
             else:
+                if function.name == "knowledge_search":
+                    response_file_search_tool = next(
+                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    )
+                    if response_file_search_tool:
+                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
                 result = await self.tool_runtime_api.invoke_tool(
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
         except Exception as e:
             error_exc = e
@@ -713,6 +730,27 @@
             )
             if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                 message.status = "failed"
+        elif function.name == "knowledge_search":
+            message = OpenAIResponseOutputMessageFileSearchToolCall(
+                id=tool_call_id,
+                queries=[tool_kwargs.get("query", "")],
+                status="completed",
+            )
+            if "document_ids" in result.metadata:
+                message.results = []
+                for i, doc_id in enumerate(result.metadata["document_ids"]):
+                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
+                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    message.results.append(
+                        {
+                            "file_id": doc_id,
+                            "filename": doc_id,
+                            "text": text,
+                            "score": score,
+                        }
+                    )
+            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")
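
For clarity, the metadata-to-results mapping performed by the knowledge_search branch above can be restated as a standalone helper. This is a sketch, not code from the patch; it assumes the aligned "document_ids" / "chunks" / "scores" lists that the memory.py change below adds to the RAG tool's invocation metadata.

# Sketch (not part of the patch): how knowledge_search metadata becomes
# file_search_call results.
def file_search_results_from_metadata(metadata: dict) -> list[dict]:
    results = []
    for i, doc_id in enumerate(metadata.get("document_ids", [])):
        results.append(
            {
                "file_id": doc_id,
                # No separate filename is tracked, so the document id is reused.
                "filename": doc_id,
                "text": metadata["chunks"][i] if "chunks" in metadata else None,
                "score": metadata["scores"][i] if "scores" in metadata else None,
            }
        )
    return results
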
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 4776d47d0..e15d067a7 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -170,6 +170,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "chunks": [c.content for c in chunks[: len(picked)]],
+                "scores": scores[: len(picked)],
             },
         )
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index 4d6c19b59..7115e4b50 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -31,6 +31,18 @@ test_response_web_search:
           search_context_size: "low"
     output: "128"
 
+test_response_file_search:
+  test_name: test_response_file_search
+  test_params:
+    case:
+      - case_id: "llama_experts"
+        input: "How many experts does the Llama 4 Maverick model have?"
+        tools:
+          - type: file_search
+            vector_store_ids:
+              - test_vector_store
+        output: "128"
+
 test_response_mcp_tool:
   test_name: test_response_mcp_tool
   test_params:
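
To connect the two changes above: the fixture's tools section is what ends up being passed to the Responses API, and the memory.py change is what lets the implementation attach text and scores to the resulting file_search_call. The score below is a made-up illustrative value.

# Illustrative only: the fixture case above is equivalent to this tools list
# (the test below passes it via case["tools"]).
tools = [{"type": "file_search", "vector_store_ids": ["test_vector_store"]}]

# With the memory.py change, a RAG query over the test document can surface
# metadata shaped like this, which the responses implementation consumes.
rag_metadata = {
    "document_ids": ["doc1"],
    "chunks": ["Llama 4 Maverick has 128 experts"],
    "scores": [0.87],
}
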
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index 28020d3b1..86b267fac 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -9,6 +9,7 @@ import json
 import httpx
 import openai
 import pytest
+from llama_stack_client import LlamaStackClient
 
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.distribution.datatypes import AuthenticationRequiredError
@@ -258,6 +259,62 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_file_search"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_file_search(
+    base_url, request, openai_client, model, provider, verification_config, case
+):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    lls_client = LlamaStackClient(base_url=base_url.replace("/v1/openai/v1", ""))
+    vector_db_id = "test_vector_store"
+
+    # Ensure the test starts from a clean vector store
+    try:
+        lls_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+    except Exception:
+        pass
+
+    lls_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+    )
+    doc_content = "Llama 4 Maverick has 128 experts"
+    chunks = [
+        {
+            "content": doc_content,
+            "mime_type": "text/plain",
+            "metadata": {
+                "document_id": "doc1",
+            },
+        },
+    ]
+    lls_client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries  # ensure it's some non-empty list
+    assert response.output[0].results[0].text == doc_content
+    assert response.output[0].results[0].score > 0
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert case["output"].lower() in response.output_text.lower().strip()
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
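
Beyond the assertions in the test above, a client might iterate the output items and branch on their type. This is a sketch, not part of the patch: openai_client and model are assumed to be an OpenAI-compatible client pointed at the Llama Stack /v1/openai/v1 endpoint and a served model id, with "test_vector_store" populated as in the test.

# Sketch (not part of the patch): inspecting file_search_call output items.
response = openai_client.responses.create(
    model=model,
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "file_search", "vector_store_ids": ["test_vector_store"]}],
)
for item in response.output:
    if item.type == "file_search_call":
        print("queries:", item.queries)
        for r in item.results or []:
            print(f"  {r.file_id} (score={r.score}): {r.text}")
    elif item.type == "message":
        print("assistant:", response.output_text)
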