mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 06:02:36 +00:00
feat: Add support for query rewrite in vector_store.search (#4171)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 20s
Test External API and Providers / test-external (venv) (push) Failing after 41s
Vector IO Integration Tests / test-matrix (push) Failing after 49s
UI Tests / ui-tests (22) (push) Successful in 51s
Unit Tests / unit-tests (3.13) (push) Failing after 1m27s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Pre-commit / pre-commit (22) (push) Failing after 2m30s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m22s
# What does this PR do?
Implements query rewriting in the vector store search API and adds a
`rewrite_query_params` block to `VectorStoresConfig`.
Makes the `rewrite_query` parameter functional in vector store search:
- `rewrite_query=false` (default): use the original query
- `rewrite_query=true`: expand the query via an LLM, or raise a clear error if no
LLM is configured (see the sketch after this list)
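Roughly, the rewrite happens once at the router: it formats the configured (or built-in) prompt with the user query, asks the configured LLM for an expansion, and forwards the rewritten text to the provider with `rewrite_query=False` so providers never rewrite again. Here is a minimal, self-contained sketch of that flow, based on the unit tests in the diff below — the `ChatMessage`/`ChatRequest` stand-ins, the `expand_query` helper, and the placeholder default prompt are hypothetical; the real logic lives in `VectorIORouter.openai_search_vector_store`:
```python
from dataclasses import dataclass, field

# Placeholder; the real default is DEFAULT_QUERY_REWRITE_PROMPT in
# llama_stack.providers.utils.memory.constants.
DEFAULT_PROMPT = "Expand this search query with relevant synonyms: {query}"


@dataclass
class ChatMessage:
    role: str
    content: str


@dataclass
class ChatRequest:
    model: str
    messages: list[ChatMessage] = field(default_factory=list)
    max_tokens: int = 100
    temperature: float = 0.3


async def expand_query(inference_api, params, query: str) -> str:
    """Rewrite `query` once at the router; the caller then forwards the
    result to the provider with rewrite_query=False."""
    if inference_api is None or params is None:
        # Matches the error path exercised in the tests below.
        raise ValueError("Query rewriting is not available")
    prompt = (params.prompt or DEFAULT_PROMPT).format(query=query)
    request = ChatRequest(
        model=f"{params.model.provider_id}/{params.model.model_id}",
        messages=[ChatMessage(role="user", content=prompt)],
        max_tokens=params.max_tokens,
        temperature=params.temperature,
    )
    response = await inference_api.openai_chat_completion(request)
    return response.choices[0].message.content
```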
Adds four settings under `rewrite_query_params` in `VectorStoresConfig`:
- `model`: LLM model for query expansion (optional)
- `prompt`: custom prompt template (optional; a built-in default is used when unset)
- `max_tokens`: configurable token limit (default: 100)
- `temperature`: configurable temperature (default: 0.3)
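For reference, a rough Pydantic sketch of the new config shapes, inferred from the YAML below and the unit tests in the diff; the authoritative definitions live in `llama_stack.core.datatypes`:
```python
from pydantic import BaseModel


class QualifiedModel(BaseModel):
    provider_id: str
    model_id: str


class RewriteQueryParams(BaseModel):
    model: QualifiedModel | None = None
    prompt: str | None = None  # must contain '{query}'; validated at startup
    max_tokens: int = 100
    temperature: float = 0.3


class VectorStoresConfig(BaseModel):
    default_provider_id: str | None = None
    default_embedding_model: QualifiedModel | None = None
    rewrite_query_params: RewriteQueryParams | None = None
```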
Minimal `run.yaml` enabling query rewrite:
```yaml
vector_stores:
  rewrite_query_params:
    model:
      provider_id: "ollama"
      model_id: "llama3.2:3b-instruct-fp16"
    # prompt defaults to built-in
    # max_tokens defaults to 100
    # temperature defaults to 0.3
```
Fully customized `run.yaml`:
```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
  rewrite_query_params:
    model:
      provider_id: ollama
      model_id: llama3.2:3b-instruct-fp16
    prompt: "Rewrite this search query to improve retrieval results by expanding it with relevant synonyms and related terms: {query}"
    max_tokens: 100
    temperature: 0.3
```
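Note that a custom prompt must contain the `{query}` placeholder; startup validation rejects configs that omit it (see `test_validate_rewrite_query_prompt_missing_placeholder` in the diff below). A minimal sketch of that check — the standalone helper here is hypothetical, the real check runs inside `validate_vector_stores_config`:
```python
def check_rewrite_prompt(prompt: str | None) -> None:
    # Hypothetical standalone version of the startup validation.
    if prompt is not None and "{query}" not in prompt:
        raise ValueError("'{query}' placeholder is required in rewrite_query_params.prompt")
```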
## Test Plan
Added tests and a recording. Example script as well:
```python
import asyncio
from llama_stack_client import LlamaStackClient
from io import BytesIO
def gen_file(client, text: str=""):
file_buffer = BytesIO(text.encode('utf-8'))
file_buffer.name = "my_file.txt"
uploaded_file = client.files.create(
file=file_buffer,
purpose="assistants"
)
return uploaded_file
async def test_query_rewriting():
client = LlamaStackClient(base_url="http://0.0.0.0:8321/")
uploaded_file = gen_file(client, "banana banana apple")
uploaded_file2 = gen_file(client, "orange orange kiwi")
vs = client.vector_stores.create()
xf_vs = client.vector_stores.files.create(vector_store_id=vs.id, file_id=uploaded_file.id)
xf_vs1 = client.vector_stores.files.create(vector_store_id=vs.id, file_id=uploaded_file2.id)
response1 = client.vector_stores.search(
vector_store_id=vs.id,
query="apple",
max_num_results=3,
rewrite_query=False
)
response2 = client.vector_stores.search(
vector_store_id=vs.id,
query="kiwi",
max_num_results=3,
rewrite_query=True,
)
print(f"\n🔵 Response 1 (rewrite_query=False):\n\033[94m{response1}\033[0m")
print(f"\n🟢 Response 2 (rewrite_query=True):\n\033[92m{response2}\033[0m")
for f in [uploaded_file.id, uploaded_file2.id]:
client.files.delete(file_id=f)
client.vector_stores.delete(vector_store_id=vs.id)
if __name__ == "__main__":
asyncio.run(test_query_rewriting())
```
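If rewriting succeeds, `response2` should reflect the expanded query, while `response1` searches with the raw query unchanged; the router hands the rewritten text to the provider with `rewrite_query=False`, so expansion happens exactly once.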
See the screenshot of the server logs showing it worked:
<img width="1111" height="826" alt="Screenshot 2025-11-19 at 1 16 03 PM"
src="https://github.com/user-attachments/assets/2d188b44-1fef-4df5-b465-2d6728ca49ce"
/>
Notice the log:
```bash
Query rewritten:
'kiwi' → 'kiwi, a small brown or green fruit native to New Zealand, or a person having a fuzzy brown outer skin similar in appearance.'
```
So `kiwi` was expanded.
---------
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
This commit is contained in:
parent: ff375f1abb
commit: 95b2948d11
22 changed files with 7636 additions and 20 deletions
```diff
@@ -153,3 +153,156 @@ async def test_create_vector_store_with_wrong_model_type_raises_error():
     with pytest.raises(ModelTypeError, match="Model 'text-model' is of type"):
         await router.openai_create_vector_store(request)
 
+
+async def test_query_rewrite_functionality():
+    """Test query rewriting at the router level."""
+    from unittest.mock import MagicMock
+
+    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams, VectorStoresConfig
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
+    from llama_stack_api import VectorStoreSearchResponsePage
+
+    mock_routing_table = Mock()
+
+    # Mock provider that returns search results
+    mock_provider = Mock()
+    mock_search_response = VectorStoreSearchResponsePage(search_query=["rewritten test query"], data=[], has_more=False)
+    mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
+    mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
+
+    # Mock inference API for query rewriting
+    mock_inference_api = Mock()
+    mock_inference_api.openai_chat_completion = AsyncMock(
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="rewritten test query"))])
+    )
+
+    # Create config with rewrite params
+    vector_stores_config = VectorStoresConfig(
+        rewrite_query_params=RewriteQueryParams(
+            model=QualifiedModel(provider_id="test", model_id="llama"),
+            max_tokens=100,
+            temperature=0.3,
+        )
+    )
+
+    router = VectorIORouter(mock_routing_table, vector_stores_config, mock_inference_api)
+
+    # Test query rewrite with rewrite_query=True
+    result = await router.openai_search_vector_store(
+        vector_store_id="vs_123",
+        query="test query",
+        rewrite_query=True,
+        max_num_results=5,
+    )
+
+    # Verify chat completion was called for query rewriting
+    assert mock_inference_api.openai_chat_completion.called
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    assert chat_call_args.model == "test/llama"
+
+    # Verify default prompt is used
+    prompt_text = chat_call_args.messages[0].content
+    expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
+    assert prompt_text == expected_prompt
+
+    # Verify provider was called with rewritten query and rewrite_query=False
+    mock_provider.openai_search_vector_store.assert_called_once()
+    call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs
+    assert call_kwargs["query"] == "rewritten test query"
+    assert call_kwargs["rewrite_query"] is False  # Should be False since router handled it
+
+    assert result is not None
+
+
+async def test_query_rewrite_error_when_not_configured():
+    """Test that query rewriting fails with proper error when not configured."""
+    mock_routing_table = Mock()
+    mock_provider = Mock()
+    mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
+
+    # No config or inference API
+    router = VectorIORouter(mock_routing_table)
+
+    with pytest.raises(ValueError, match="Query rewriting is not available"):
+        await router.openai_search_vector_store(
+            vector_store_id="vs_123",
+            query="test query",
+            rewrite_query=True,
+            max_num_results=5,
+        )
+
+
+async def test_query_rewrite_with_custom_prompt():
+    """Test query rewriting with custom prompt."""
+    from unittest.mock import MagicMock
+
+    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams, VectorStoresConfig
+    from llama_stack_api import VectorStoreSearchResponsePage
+
+    mock_routing_table = Mock()
+
+    mock_provider = Mock()
+    mock_search_response = VectorStoreSearchResponsePage(search_query=["custom rewrite"], data=[], has_more=False)
+    mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
+    mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
+
+    mock_inference_api = Mock()
+    mock_inference_api.openai_chat_completion = AsyncMock(
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="custom rewrite"))])
+    )
+
+    vector_stores_config = VectorStoresConfig(
+        rewrite_query_params=RewriteQueryParams(
+            model=QualifiedModel(provider_id="test", model_id="llama"),
+            prompt="Custom prompt: {query}",
+            max_tokens=150,
+            temperature=0.7,
+        )
+    )
+
+    router = VectorIORouter(mock_routing_table, vector_stores_config, mock_inference_api)
+
+    await router.openai_search_vector_store(
+        vector_store_id="vs_123",
+        query="test query",
+        rewrite_query=True,
+        max_num_results=5,
+    )
+
+    # Verify custom prompt was used
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    assert chat_call_args.messages[0].content == "Custom prompt: test query"
+    assert chat_call_args.max_tokens == 150
+    assert chat_call_args.temperature == 0.7
+
+
+async def test_search_without_rewrite():
+    """Test that search without rewrite_query doesn't call inference API."""
+    from llama_stack_api import VectorStoreSearchResponsePage
+
+    mock_routing_table = Mock()
+
+    mock_provider = Mock()
+    mock_search_response = VectorStoreSearchResponsePage(search_query=["test query"], data=[], has_more=False)
+    mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
+    mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
+
+    mock_inference_api = Mock()
+    mock_inference_api.openai_chat_completion = AsyncMock()
+
+    router = VectorIORouter(mock_routing_table, inference_api=mock_inference_api)
+
+    await router.openai_search_vector_store(
+        vector_store_id="vs_123",
+        query="test query",
+        rewrite_query=False,
+        max_num_results=5,
+    )
+
+    # Verify inference API was NOT called
+    assert not mock_inference_api.openai_chat_completion.called
+
+    # Verify provider was called with original query
+    call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs
+    assert call_kwargs["query"] == "test query"
```
```diff
@@ -10,7 +10,13 @@ from unittest.mock import AsyncMock
 import pytest
 
-from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackConfig, VectorStoresConfig
+from llama_stack.core.datatypes import (
+    QualifiedModel,
+    RewriteQueryParams,
+    SafetyConfig,
+    StackConfig,
+    VectorStoresConfig,
+)
 from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
 from llama_stack.core.storage.datatypes import ServerStoresConfig, StorageConfig
 from llama_stack_api import Api, ListModelsResponse, ListShieldsResponse, Model, ModelType, Shield
```
```diff
@@ -82,6 +88,17 @@ class TestVectorStoresValidation:
         await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
 
+    async def test_validate_rewrite_query_prompt_missing_placeholder(self):
+        """Test validation fails when prompt template is missing {query} placeholder."""
+        config = VectorStoresConfig(
+            rewrite_query_params=RewriteQueryParams(
+                prompt="This prompt has no placeholder",
+            ),
+        )
+
+        with pytest.raises(ValueError, match="'\\{query\\}' placeholder is required"):
+            await validate_vector_stores_config(config, {})
+
 
 class TestSafetyConfigValidation:
     async def test_validate_success(self):
```
```diff
@@ -1230,3 +1230,40 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
     with pytest.raises(ValueError, match="embedding_model is required"):
         await vector_io_adapter.openai_create_vector_store(params)
 
+
+async def test_search_vector_store_ignores_rewrite_query(vector_io_adapter):
+    """Test that the mixin ignores rewrite_query parameter since rewriting is done at router level."""
+    from llama_stack_api import QueryChunksResponse
+
+    # Create an OpenAI vector store for testing directly in the adapter's cache
+    vector_store_id = "test_store_rewrite"
+    openai_vector_store = {
+        "id": vector_store_id,
+        "name": "Test Store",
+        "description": "A test OpenAI vector store",
+        "vector_store_id": "test_db",
+        "embedding_model": "test/embedding",
+    }
+    vector_io_adapter.openai_vector_stores[vector_store_id] = openai_vector_store
+
+    # Mock query_chunks response from adapter
+    mock_response = QueryChunksResponse(chunks=[], scores=[])
+
+    async def mock_query_chunks(*args, **kwargs):
+        return mock_response
+
+    vector_io_adapter.query_chunks = mock_query_chunks
+
+    # Test that rewrite_query=True doesn't cause an error (it's ignored at mixin level)
+    # The mixin should process the search request without attempting to rewrite the query
+    result = await vector_io_adapter.openai_search_vector_store(
+        vector_store_id=vector_store_id,
+        query="test query",
+        rewrite_query=True,  # This should be ignored at mixin level
+        max_num_results=5,
+    )
+
+    # Search should succeed - the mixin ignores rewrite_query and just does the search
+    assert result is not None
+    assert result.search_query == ["test query"]  # Original query preserved
```