mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: File search tool for Responses API (#2426)
# What does this PR do? This is an initial working prototype of wiring up the `file_search` builtin tool for the Responses API to our existing rag knowledge search tool. This is me seeing what I could pull together on top of the bits we already have merged. This may not be the ideal way to implement this, and things like how I shuffle the vector store ids from the original response API tool request to the actual tool execution feel a bit hacky (grep for `tool_kwargs["vector_db_ids"]` in `_execute_tool_call` to see what I mean). ## Test Plan I stubbed in some new tests to exercise this using text and pdf documents. Note that this is currently under tests/verification only because it sometimes flakes with tool calling of the small Llama-3.2-3B model we run in CI (and that I use as an example below). We'd want to make the test a bit more robust in some way if we moved this over to tests/integration and ran it in CI. ### OpenAI SaaS (to verify test correctness) ``` pytest -sv tests/verifications/openai_api/test_responses.py \ -k 'file_search' \ --base-url=https://api.openai.com/v1 \ --model=gpt-4o ``` ### Fireworks with faiss vector store ``` llama stack run llama_stack/templates/fireworks/run.yaml pytest -sv tests/verifications/openai_api/test_responses.py \ -k 'file_search' \ --base-url=http://localhost:8321/v1/openai/v1 \ --model=meta-llama/Llama-3.3-70B-Instruct ``` ### Ollama with faiss vector store This sometimes flakes on Ollama because the quantized small model doesn't always choose to call the tool to answer the user's question. But, it often works. ``` ollama run llama3.2:3b INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \ llama stack run ./llama_stack/templates/ollama/run.yaml \ --image-type venv \ --env OLLAMA_URL="http://0.0.0.0:11434" pytest -sv tests/verifications/openai_api/test_responses.py \ -k'file_search' \ --base-url=http://localhost:8321/v1/openai/v1 \ --model=meta-llama/Llama-3.2-3B-Instruct ``` ### OpenAI provider with sqlite-vec vector store ``` llama stack run ./llama_stack/templates/starter/run.yaml --image-type venv pytest -sv tests/verifications/openai_api/test_responses.py \ -k 'file_search' \ --base-url=http://localhost:8321/v1/openai/v1 \ --model=openai/gpt-4o-mini ``` ### Ensure existing vector store integration tests still pass ``` ollama run llama3.2:3b INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \ llama stack run ./llama_stack/templates/ollama/run.yaml \ --image-type venv \ --env OLLAMA_URL="http://0.0.0.0:11434" LLAMA_STACK_CONFIG=http://localhost:8321 \ pytest -sv tests/integration/vector_io \ --text-model "meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 ``` --------- Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
554ada57b0
commit
941f505eb0
28 changed files with 1105 additions and 24 deletions
Binary file not shown.
|
@ -31,6 +31,25 @@ test_response_web_search:
|
|||
search_context_size: "low"
|
||||
output: "128"
|
||||
|
||||
test_response_file_search:
|
||||
test_name: test_response_file_search
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "llama_experts"
|
||||
input: "How many experts does the Llama 4 Maverick model have?"
|
||||
tools:
|
||||
- type: file_search
|
||||
# vector_store_ids param for file_search tool gets added by the test runner
|
||||
file_content: "Llama 4 Maverick has 128 experts"
|
||||
output: "128"
|
||||
- case_id: "llama_experts_pdf"
|
||||
input: "How many experts does the Llama 4 Maverick model have?"
|
||||
tools:
|
||||
- type: file_search
|
||||
# vector_store_ids param for file_search toolgets added by the test runner
|
||||
file_path: "pdfs/llama_stack_and_models.pdf"
|
||||
output: "128"
|
||||
|
||||
test_response_mcp_tool:
|
||||
test_name: test_response_mcp_tool
|
||||
test_params:
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
@ -23,6 +25,31 @@ from tests.verifications.openai_api.fixtures.load import load_test_cases
|
|||
responses_test_cases = load_test_cases("responses")
|
||||
|
||||
|
||||
def _new_vector_store(openai_client, name):
|
||||
# Ensure we don't reuse an existing vector store
|
||||
vector_stores = openai_client.vector_stores.list()
|
||||
for vector_store in vector_stores:
|
||||
if vector_store.name == name:
|
||||
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
|
||||
# Create a new vector store
|
||||
vector_store = openai_client.vector_stores.create(
|
||||
name=name,
|
||||
)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _upload_file(openai_client, name, file_path):
|
||||
# Ensure we don't reuse an existing file
|
||||
files = openai_client.files.list()
|
||||
for file in files:
|
||||
if file.filename == name:
|
||||
openai_client.files.delete(file_id=file.id)
|
||||
|
||||
# Upload a text file with our document content
|
||||
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_basic"]["test_params"]["case"],
|
||||
|
@ -258,6 +285,111 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
|
|||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_file_search"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_file_search(
|
||||
request, openai_client, model, provider, verification_config, tmp_path, case
|
||||
):
|
||||
if isinstance(openai_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
vector_store = _new_vector_store(openai_client, "test_vector_store")
|
||||
|
||||
if "file_content" in case:
|
||||
file_name = "test_response_non_streaming_file_search.txt"
|
||||
file_path = tmp_path / file_name
|
||||
file_path.write_text(case["file_content"])
|
||||
elif "file_path" in case:
|
||||
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case["file_path"])
|
||||
file_name = os.path.basename(file_path)
|
||||
else:
|
||||
raise ValueError(f"No file content or path provided for case {case['case_id']}")
|
||||
|
||||
file_response = _upload_file(openai_client, file_name, file_path)
|
||||
|
||||
# Attach our file to the vector store
|
||||
file_attach_response = openai_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
|
||||
# Wait for the file to be attached
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = openai_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
|
||||
assert not file_attach_response.last_error
|
||||
|
||||
# Update our tools with the right vector store id
|
||||
tools = case["tools"]
|
||||
for tool in tools:
|
||||
if tool["type"] == "file_search":
|
||||
tool["vector_store_ids"] = [vector_store.id]
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = openai_client.responses.create(
|
||||
model=model,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert response.output[0].results
|
||||
assert case["output"].lower() in response.output[0].results[0].text.lower()
|
||||
assert response.output[0].results[0].score > 0
|
||||
|
||||
# Verify the output_text generated by the response
|
||||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
def test_response_non_streaming_file_search_empty_vector_store(
|
||||
request, openai_client, model, provider, verification_config
|
||||
):
|
||||
if isinstance(openai_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
vector_store = _new_vector_store(openai_client, "test_vector_store")
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = openai_client.responses.create(
|
||||
model=model,
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert not response.output[0].results # ensure we don't get any results
|
||||
|
||||
# Verify some output_text was generated by the response
|
||||
assert response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue