feat: Implement query rewrite in search API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Add query expansion model to vector store config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-11-16 23:56:59 -05:00
parent dabebdd230
commit 61a4738a12
20 changed files with 7381 additions and 0 deletions


@@ -376,6 +376,14 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
+    default_query_expansion_model: QualifiedModel | None = Field(
+        default=None,
+        description="Default LLM model for query expansion/rewriting in vector search.",
+    )
+    query_expansion_prompt: str = Field(
+        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
+    )
 
 
 class SafetyConfig(BaseModel):
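
For reference, a minimal sketch (not part of this commit) of populating the two new fields from Python. The QualifiedModel import path is assumed to sit alongside VectorStoresConfig, and the provider/model ids are placeholders:

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

config = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(
        provider_id="ollama",  # placeholder provider
        model_id="llama3.2:3b",  # placeholder; must be an LLM, not an embedding model
    ),
    # query_expansion_prompt is omitted here, so the default template above applies;
    # any override must keep the {query} placeholder.
)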


@@ -99,6 +99,12 @@ class VectorIORouter(VectorIO):
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
+
+        # Ensure params dict exists and add vector_stores_config for query rewriting
+        if params is None:
+            params = {}
+        params["vector_stores_config"] = self.vector_stores_config
+
         return await provider.query_chunks(vector_store_id, query, params)
 
     # OpenAI Vector Stores API endpoints
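
Threading the config through the params dict keeps the provider-facing query_chunks signature unchanged; VectorStoreWithIndex reads the key back out (see the VectorStoreWithIndex hunks below), and providers that never look for it are unaffected.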


@@ -288,5 +288,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
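
Note that this and the five config hunks that follow set only the prompt template. Query rewriting still requires a default_query_expansion_model to be configured; without one, _rewrite_query_for_search raises a ValueError (see the final hunk).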


@@ -279,5 +279,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard


@@ -291,5 +291,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard


@@ -282,5 +282,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard


@@ -288,5 +288,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard


@@ -279,5 +279,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard


@@ -611,6 +611,7 @@ class OpenAIVectorStoreMixin(ABC):
             "max_chunks": max_num_results * CHUNK_MULTIPLIER,
             "score_threshold": score_threshold,
             "mode": search_mode,
+            "rewrite_query": rewrite_query,
         }
 
         # TODO: Add support for ranking_options.ranker
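
With the flag plumbed into the search params, a client can opt in per request. A hedged sketch against the OpenAI-compatible endpoint, assuming a locally running Llama Stack server; the base URL, API key, and store id are placeholders, and rewrite_query mirrors the OpenAI Vector Stores search API:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # placeholder endpoint

results = client.vector_stores.search(
    vector_store_id="vs_123",  # placeholder vector store id
    query="how do llamas regulate body temperature",
    rewrite_query=True,  # triggers the server-side query rewrite below
)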


@@ -17,6 +17,7 @@ import numpy as np
 from numpy.typing import NDArray
 from pydantic import BaseModel
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -34,6 +35,11 @@ from llama_stack_api import (
     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType
 
 log = get_logger(name=__name__, category="providers::utils")
@@ -262,6 +268,7 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
+    vector_stores_config: VectorStoresConfig | None = None
 
     async def insert_chunks(
         self,
@@ -296,6 +303,11 @@ class VectorStoreWithIndex:
     ) -> QueryChunksResponse:
         if params is None:
             params = {}
+
+        # Extract configuration if provided by router
+        if "vector_stores_config" in params:
+            self.vector_stores_config = params["vector_stores_config"]
+
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -318,6 +330,11 @@ class VectorStoreWithIndex:
             reranker_params = {"impact_factor": k_value}
 
         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
@@ -333,3 +350,67 @@ class VectorStoreWithIndex:
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query optimized for vector search
+        """
+        # Check if query expansion model is configured
+        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError("No default_query_expansion_model configured for query rewriting")
+
+        # Use the configured model
+        expansion_model = self.vector_stores_config.default_query_expansion_model
+        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
+
+        # Validate that the model is available and is an LLM
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
+
+        model_found = False
+        for model in models_response.data:
+            if model.identifier == chat_model:
+                if model.model_type != ModelType.llm:
+                    raise ValueError(
+                        f"Configured query expansion model '{chat_model}' is not an LLM model "
+                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
+                    )
+                model_found = True
+                break
+
+        if not model_found:
+            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
+            raise ValueError(
+                f"Configured query expansion model '{chat_model}' is not available. "
+                f"Available LLM models: {available_llm_models}"
+            )
+
+        # Use the configured prompt (has a default value)
+        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        if response.choices and len(response.choices) > 0:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM model for query rewriting")
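
A hedged test sketch of the strict failure mode (the fixture name and pytest-asyncio setup are assumptions, not part of this commit): with no expansion model configured, the helper raises rather than silently searching with the original query.

import pytest

@pytest.mark.asyncio  # requires pytest-asyncio
async def test_rewrite_requires_configured_model(vector_store_with_index):
    # fixture assumed to yield a VectorStoreWithIndex instance
    vector_store_with_index.vector_stores_config = None
    with pytest.raises(ValueError, match="default_query_expansion_model"):
        await vector_store_with_index._rewrite_query_for_search("tell me about llamas")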