feat: Add support for query rewrite in vector_store.search (#4171)

# What does this PR do?

Implements query rewriting in the vector store search API and adds a
`rewrite_query_params` block to `VectorStoresConfig`.

Makes the `rewrite_query` parameter functional in vector store search:
- `rewrite_query=false` (default): use the original query
- `rewrite_query=true`: expand the query via the configured LLM, or fail with
a clear error if no LLM is available

Adds four fields under `rewrite_query_params` in `VectorStoresConfig` (see the
`run.yaml` examples below):
- `model`: LLM model for query expansion (optional)
- `prompt`: custom prompt template (optional, defaults to the built-in prompt;
must contain the `{query}` placeholder)
- `max_tokens`: token limit for the expansion response (default: 100)
- `temperature`: sampling temperature for the expansion model (default: 0.3)

Minimal `run.yaml` enabling query rewriting:
```yaml
vector_stores:
  rewrite_query_params:
    model:
      provider_id: "ollama"
      model_id: "llama3.2:3b-instruct-fp16"
    # prompt defaults to built-in
    # max_tokens defaults to 100
    # temperature defaults to 0.3
```

Fully customized `run.yaml`:
```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
  rewrite_query_params:
    model:
      provider_id: ollama
      model_id: llama3.2:3b-instruct-fp16
    prompt: "Rewrite this search query to improve retrieval results by expanding it with relevant synonyms and related terms: {query}"
    max_tokens: 100
    temperature: 0.3
```
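
For intuition, here is a standalone sketch (not part of the PR) of how a configured template is applied before the LLM call: the router substitutes the original query via `str.format`, which is why startup validation rejects templates missing the `{query}` placeholder.

```python
# Standalone sketch: rendering the custom prompt template from the run.yaml above.
prompt_template = (
    "Rewrite this search query to improve retrieval results by expanding it"
    " with relevant synonyms and related terms: {query}"
)

# Startup validation requires the placeholder; mirror that check here.
assert "{query}" in prompt_template

# The router renders the template with the user's original query.
print(prompt_template.format(query="kiwi"))
```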

## Test Plan
Added a test and a recording.

Example script:

```python
import asyncio
from io import BytesIO

from llama_stack_client import LlamaStackClient


def gen_file(client, text: str = ""):
    # Upload an in-memory text file for the vector store to index
    file_buffer = BytesIO(text.encode("utf-8"))
    file_buffer.name = "my_file.txt"
    return client.files.create(file=file_buffer, purpose="assistants")


async def test_query_rewriting():
    client = LlamaStackClient(base_url="http://0.0.0.0:8321/")
    uploaded_file = gen_file(client, "banana banana apple")
    uploaded_file2 = gen_file(client, "orange orange kiwi")

    vs = client.vector_stores.create()
    client.vector_stores.files.create(vector_store_id=vs.id, file_id=uploaded_file.id)
    client.vector_stores.files.create(vector_store_id=vs.id, file_id=uploaded_file2.id)

    # Baseline: search with the original query untouched
    response1 = client.vector_stores.search(
        vector_store_id=vs.id,
        query="apple",
        max_num_results=3,
        rewrite_query=False,
    )
    # Expanded: the server rewrites the query via the configured LLM first
    response2 = client.vector_stores.search(
        vector_store_id=vs.id,
        query="kiwi",
        max_num_results=3,
        rewrite_query=True,
    )

    print(f"\n🔵 Response 1 (rewrite_query=False):\n\033[94m{response1}\033[0m")
    print(f"\n🟢 Response 2 (rewrite_query=True):\n\033[92m{response2}\033[0m")

    # Clean up the uploaded files and the vector store
    for f in [uploaded_file.id, uploaded_file2.id]:
        client.files.delete(file_id=f)
    client.vector_stores.delete(vector_store_id=vs.id)


if __name__ == "__main__":
    asyncio.run(test_query_rewriting())
```
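
Note: the script assumes a server at `http://0.0.0.0:8321` with `rewrite_query_params` configured in its `run.yaml`. If no rewrite model is configured, the `rewrite_query=True` search fails with "Query rewriting is not available" rather than silently falling back to the original query.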

See the screenshot of the server logs showing it worked:
<img width="1111" height="826" alt="Screenshot 2025-11-19 at 1 16 03 PM"
src="https://github.com/user-attachments/assets/2d188b44-1fef-4df5-b465-2d6728ca49ce"
/>

Notice the log:
```bash
 Query rewritten:
         'kiwi' → 'kiwi, a small brown or green fruit native to New Zealand, or a person having a fuzzy brown outer skin similar in appearance.'
```
So `kiwi` was expanded.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
```diff
@@ -18,6 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,
@@ -349,6 +350,27 @@ class QualifiedModel(BaseModel):
     model_id: str
 
 
+class RewriteQueryParams(BaseModel):
+    """Parameters for query rewriting/expansion."""
+
+    model: QualifiedModel | None = Field(
+        default=None,
+        description="LLM model for query rewriting/expansion in vector search.",
+    )
+    prompt: str = Field(
+        default=DEFAULT_QUERY_REWRITE_PROMPT,
+        description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
+    )
+    max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens for query expansion responses.",
+    )
+    temperature: float = Field(
+        default=0.3,
+        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+    )
+
+
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""
@@ -360,6 +382,10 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
+    rewrite_query_params: RewriteQueryParams | None = Field(
+        default=None,
+        description="Parameters for query rewriting/expansion. None disables query rewriting.",
+    )
 
 
 class SafetyConfig(BaseModel):
```
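
As a quick illustration of the config semantics (a minimal standalone mirror of the models above, not the actual `llama_stack` classes): leaving `rewrite_query_params` unset disables rewriting entirely, while providing only a model picks up the built-in defaults for the other fields.

```python
# Minimal standalone mirror of the new config models (illustration only).
from pydantic import BaseModel


class QualifiedModel(BaseModel):
    provider_id: str
    model_id: str


class RewriteQueryParams(BaseModel):
    model: QualifiedModel | None = None
    prompt: str = "Expand this query ...: {query}"  # stand-in for the built-in default
    max_tokens: int = 100
    temperature: float = 0.3


class VectorStoresConfig(BaseModel):
    rewrite_query_params: RewriteQueryParams | None = None


# Omitted entirely -> rewriting disabled (None).
assert VectorStoresConfig().rewrite_query_params is None

# Only the model given -> prompt/max_tokens/temperature fall back to defaults.
cfg = VectorStoresConfig.model_validate(
    {"rewrite_query_params": {"model": {"provider_id": "ollama", "model_id": "llama3.2:3b-instruct-fp16"}}}
)
assert cfg.rewrite_query_params.max_tokens == 100
```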

```diff
@@ -199,6 +199,13 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
             )
         }
 
+        # Add inference as an optional dependency for vector_io to enable query rewriting
+        optional_deps = []
+        deps_list = [info.routing_table_api.value]
+        if info.router_api == Api.vector_io:
+            optional_deps = [Api.inference]
+            deps_list.append(Api.inference.value)
+
         specs[info.router_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__autorouted__",
@@ -209,7 +216,8 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str,
                     module="llama_stack.core.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
-                    deps__=([info.routing_table_api.value]),
+                    optional_api_dependencies=optional_deps,
+                    deps__=deps_list,
                 ),
             )
         }
@@ -315,6 +323,13 @@ async def instantiate_providers(
             api = Api(api_str)
             impls[api] = impl
 
+    # Post-instantiation: Inject VectorIORouter into VectorStoresRoutingTable
+    if Api.vector_io in impls and Api.vector_stores in impls:
+        vector_io_router = impls[Api.vector_io]
+        vector_stores_routing_table = impls[Api.vector_stores]
+        if hasattr(vector_stores_routing_table, "vector_io_router"):
+            vector_stores_routing_table.vector_io_router = vector_io_router
+
     return impls
```
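
Making inference an *optional* dependency of the autorouted `vector_io` API (rather than a required one) means distributions without an inference provider still resolve and start; in that case `rewrite_query=true` fails at search time with a clear error instead of blocking startup.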

```diff
@@ -87,6 +87,7 @@ async def get_auto_router_impl(
         api_to_dep_impl["store"] = inference_store
     elif api == Api.vector_io:
         api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
+        api_to_dep_impl["inference_api"] = deps.get(Api.inference)
     elif api == Api.safety:
         api_to_dep_impl["safety_config"] = run_config.safety
```

```diff
@@ -16,12 +16,15 @@ from llama_stack_api import (
     Chunk,
     HealthResponse,
     HealthStatus,
+    Inference,
     InterleavedContent,
     ModelNotFoundError,
     ModelType,
     ModelTypeError,
+    OpenAIChatCompletionRequestWithExtraBody,
     OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
     OpenAICreateVectorStoreRequestWithExtraBody,
+    OpenAIUserMessageParam,
     QueryChunksResponse,
     RoutingTable,
     SearchRankingOptions,
@@ -51,10 +54,11 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
         vector_stores_config: VectorStoresConfig | None = None,
+        inference_api: Inference | None = None,
     ) -> None:
         logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table
         self.vector_stores_config = vector_stores_config
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         logger.debug("VectorIORouter.initialize")
@@ -64,6 +68,46 @@ class VectorIORouter(VectorIO):
         logger.debug("VectorIORouter.shutdown")
         pass
 
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite a search query using the configured LLM model for better retrieval results."""
+        if (
+            not self.vector_stores_config
+            or not self.vector_stores_config.rewrite_query_params
+            or not self.vector_stores_config.rewrite_query_params.model
+        ):
+            logger.warning(
+                "User is trying to use vector_store query rewriting, but it is not configured. Please configure rewrite_query_params.model in vector_stores config."
+            )
+            raise ValueError("Query rewriting is not available")
+
+        if not self.inference_api:
+            logger.warning("Query rewriting requires inference API but it is not available")
+            raise ValueError("Query rewriting is not available")
+
+        model = self.vector_stores_config.rewrite_query_params.model
+        model_id = f"{model.provider_id}/{model.model_id}"
+        prompt = self.vector_stores_config.rewrite_query_params.prompt.format(query=query)
+
+        request = OpenAIChatCompletionRequestWithExtraBody(
+            model=model_id,
+            messages=[OpenAIUserMessageParam(role="user", content=prompt)],
+            max_tokens=self.vector_stores_config.rewrite_query_params.max_tokens or 100,
+            temperature=self.vector_stores_config.rewrite_query_params.temperature or 0.3,
+        )
+        try:
+            response = await self.inference_api.openai_chat_completion(request)
+            content = response.choices[0].message.content
+            if content is None:
+                logger.error(f"LLM returned None content for query rewriting. Model: {model_id}")
+                raise RuntimeError("Query rewrite failed due to an internal error")
+            rewritten_query: str = content.strip()
+            return rewritten_query
+        except Exception as e:
+            logger.error(f"Query rewrite failed with LLM call error. Model: {model_id}, Error: {e}")
+            raise RuntimeError("Query rewrite failed due to an internal error") from e
+
     async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
         """Get the embedding dimension for a specific embedding model."""
         all_models = await self.routing_table.get_all_with_type("model")
@@ -292,14 +336,24 @@ class VectorIORouter(VectorIO):
         search_mode: str | None = "vector",
     ) -> VectorStoreSearchResponsePage:
         logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
+
+        # Handle query rewriting at the router level
+        search_query = query
+        if rewrite_query:
+            if isinstance(query, list):
+                original_query = " ".join(query)
+            else:
+                original_query = query
+            search_query = await self._rewrite_query_for_search(original_query)
+
         provider = await self.routing_table.get_provider_impl(vector_store_id)
         return await provider.openai_search_vector_store(
             vector_store_id=vector_store_id,
-            query=query,
+            query=search_query,
             filters=filters,
             max_num_results=max_num_results,
             ranking_options=ranking_options,
-            rewrite_query=rewrite_query,
+            rewrite_query=False,  # Already handled at router level
             search_mode=search_mode,
         )
```
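
One detail worth noting from the search path above: list-valued queries are joined into a single string before rewriting. A trivial standalone illustration:

```python
# Illustration of the normalization in openai_search_vector_store above.
query = ["fresh", "kiwi"]
original_query = " ".join(query) if isinstance(query, list) else query
assert original_query == "fresh kiwi"  # this single string is what gets rewritten
```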

```diff
@@ -40,6 +40,15 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
     Only provides internal routing functionality for VectorIORouter.
     """
 
+    def __init__(
+        self,
+        impls_by_provider_id: dict[str, Any],
+        dist_registry: Any,
+        policy: list[Any],
+    ) -> None:
+        super().__init__(impls_by_provider_id, dist_registry, policy)
+        self.vector_io_router = None  # Will be set post-instantiation
+
     # Internal methods only - no public API exposure
     async def register_vector_store(
@@ -133,6 +142,20 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
         search_mode: str | None = "vector",
     ) -> VectorStoreSearchResponsePage:
         await self.assert_action_allowed("read", "vector_store", vector_store_id)
+
+        # Delegate to VectorIORouter if available (which handles query rewriting)
+        if self.vector_io_router is not None:
+            return await self.vector_io_router.openai_search_vector_store(
+                vector_store_id=vector_store_id,
+                query=query,
+                filters=filters,
+                max_num_results=max_num_results,
+                ranking_options=ranking_options,
+                rewrite_query=rewrite_query,
+                search_mode=search_mode,
+            )
+
+        # Fallback to direct provider call if VectorIORouter not available
         provider = await self.get_provider_impl(vector_store_id)
         return await provider.openai_search_vector_store(
             vector_store_id=vector_store_id,
```

```diff
@@ -16,7 +16,7 @@ import yaml
 from pydantic import BaseModel
 
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, SafetyConfig, StackConfig, VectorStoresConfig
+from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -221,35 +221,71 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
     if vector_stores_config is None:
         return
 
-    default_embedding_model = vector_stores_config.default_embedding_model
-    if default_embedding_model is None:
-        return
+    # Validate default embedding model
+    if vector_stores_config.default_embedding_model is not None:
+        await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)
 
-    provider_id = default_embedding_model.provider_id
-    model_id = default_embedding_model.model_id
-    default_model_id = f"{provider_id}/{model_id}"
+    # Validate rewrite query params
+    if vector_stores_config.rewrite_query_params:
+        if vector_stores_config.rewrite_query_params.model:
+            await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
+        if "{query}" not in vector_stores_config.rewrite_query_params.prompt:
+            raise ValueError("'{query}' placeholder is required in the prompt template")
+
+
+async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that an embedding model exists and has required metadata."""
+    provider_id = embedding_model.provider_id
+    model_id = embedding_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"
 
     if Api.models not in impls:
-        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
+        raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")
 
     models_impl = impls[Api.models]
     response = await models_impl.list_models()
     models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
-    default_model = models_list.get(default_model_id)
-    if default_model is None:
-        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
+    model = models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
+        )
 
-    embedding_dimension = default_model.metadata.get("embedding_dimension")
+    embedding_dimension = model.metadata.get("embedding_dimension")
     if embedding_dimension is None:
-        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
+        raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")
 
     try:
         int(embedding_dimension)
     except ValueError as err:
         raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
 
-    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+    logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")
+
+
+async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that a rewrite query model exists and is accessible."""
+    provider_id = rewrite_query_model.provider_id
+    model_id = rewrite_query_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"
+
+    if Api.models not in impls:
+        raise ValueError(
+            f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
+        )
+
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+    model = llm_models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
+        )
+
+    logger.debug(f"Validated rewrite query model: {model_identifier}")
 
 
 async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
```

```diff
@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .constants import DEFAULT_QUERY_REWRITE_PROMPT
+
+__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]
```

```diff
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Default prompt template for query rewriting in vector search
+DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
```
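
To see what the model actually receives with the built-in template, a quick rendering (constant copied from the new module above, query taken from the Test Plan):

```python
# Copied from the new constants module; rendered for the Test Plan's query.
DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
print(DEFAULT_QUERY_REWRITE_PROMPT.format(query="kiwi"))
```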

```diff
@@ -592,7 +592,11 @@ class OpenAIVectorStoreMixin(ABC):
             str | None
         ) = "vector",  # Using str instead of Literal due to OpenAPI schema generator limitations
     ) -> VectorStoreSearchResponsePage:
-        """Search for chunks in a vector store."""
+        """Search for chunks in a vector store.
+
+        Note: Query rewriting is handled at the router level, not here.
+        The rewrite_query parameter is kept for API compatibility but is ignored.
+        """
         max_num_results = max_num_results or 10
 
         # Validate search_mode
```