Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
feat: Actualize query rewrite in search API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

adding query expansion model to vector store config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent dabebdd230
commit 61a4738a12

20 changed files with 7381 additions and 0 deletions
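The rewrite behavior is opt-in per search request. As a minimal illustration, a client could trigger it through the OpenAI-compatible vector store search endpoint; this sketch assumes a locally running Llama Stack and uses the openai SDK, with the base URL, API key, and store ID as illustrative placeholders rather than anything taken from this diff:

from openai import OpenAI

# Sketch: opt into server-side query rewriting on a single search request.
# base_url/api_key/vector_store_id are hypothetical placeholders.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

results = client.vector_stores.search(
    vector_store_id="vs_123",          # hypothetical store ID
    query="how do I tune chunk size?",
    rewrite_query=True,                # forwarded into params["rewrite_query"] below
)
for hit in results.data:
    print(hit.score, hit.filename)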
@@ -376,6 +376,14 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
+    default_query_expansion_model: QualifiedModel | None = Field(
+        default=None,
+        description="Default LLM model for query expansion/rewriting in vector search.",
+    )
+    query_expansion_prompt: str = Field(
+        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
+    )
 
 
 class SafetyConfig(BaseModel):
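For illustration, the new fields can also be set programmatically. This is a sketch assuming QualifiedModel is importable from the same datatypes module and takes provider_id/model_id (as the run.yaml hunks below suggest); the model names are placeholders:

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

# Sketch: programmatic equivalent of the run.yaml entries below.
config = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(
        provider_id="ollama",        # hypothetical provider
        model_id="llama3.2:3b",      # hypothetical rewriting LLM
    ),
    # query_expansion_prompt keeps its default "{query}" template
)
print(config.query_expansion_prompt.format(query="best gpu for local inference"))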
@@ -99,6 +99,12 @@ class VectorIORouter(VectorIO):
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
+
+        # Ensure params dict exists and add vector_stores_config for query rewriting
+        if params is None:
+            params = {}
+        params["vector_stores_config"] = self.vector_stores_config
+
         return await provider.query_chunks(vector_store_id, query, params)
 
     # OpenAI Vector Stores API endpoints
@@ -288,5 +288,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
@@ -279,5 +279,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
@@ -291,5 +291,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
@@ -282,5 +282,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
@@ -288,5 +288,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
@@ -279,5 +279,13 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
+  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
+    Return only the improved query, no explanations:
+
+
+    {query}
+
+
+    Improved query:'
 safety:
   default_shield_id: llama-guard
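The six identical blocks above land in six different distributions' run.yaml files. Because a single-quoted YAML scalar folds single line breaks to spaces and a run of blank lines to newlines, each block parses back to exactly the Python default template; a quick round-trip check, as a sketch assuming PyYAML is available:

import yaml

# Sketch: YAML flow-scalar folding reproduces the Python default prompt.
snippet = """
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
  Return only the improved query, no explanations:


  {query}


  Improved query:'
"""
parsed = yaml.safe_load(snippet)["query_expansion_prompt"]
assert parsed == (
    "Expand this query with relevant synonyms and related terms. "
    "Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
)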
@@ -611,6 +611,7 @@ class OpenAIVectorStoreMixin(ABC):
             "max_chunks": max_num_results * CHUNK_MULTIPLIER,
             "score_threshold": score_threshold,
             "mode": search_mode,
+            "rewrite_query": rewrite_query,
         }
         # TODO: Add support for ranking_options.ranker
 
@@ -17,6 +17,7 @@ import numpy as np
 from numpy.typing import NDArray
 from pydantic import BaseModel
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -34,6 +35,11 @@ from llama_stack_api import (
     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType
 
 log = get_logger(name=__name__, category="providers::utils")
 
@@ -262,6 +268,7 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
+    vector_stores_config: VectorStoresConfig | None = None
 
     async def insert_chunks(
         self,
@@ -296,6 +303,11 @@
     ) -> QueryChunksResponse:
         if params is None:
             params = {}
+
+        # Extract configuration if provided by router
+        if "vector_stores_config" in params:
+            self.vector_stores_config = params["vector_stores_config"]
+
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -318,6 +330,11 @@
         reranker_params = {"impact_factor": k_value}
 
         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
 
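At this layer the rewrite stays opt-in through params, with vector_stores_config having been injected by the router hunk earlier. A usage sketch, where store is a hypothetical, already-constructed VectorStoreWithIndex:

# Sketch: per-request opt-in to rewriting; `store` is hypothetical and the
# params keys mirror the ones read in the hunks above.
params = {
    "max_chunks": 5,
    "mode": "vector",
    "score_threshold": 0.0,
    "rewrite_query": True,  # routes the query through _rewrite_query_for_search
}
response = await store.query_chunks("gpu out-of-memory during inference", params)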
@@ -333,3 +350,67 @@
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query optimized for vector search
+        """
+        # Check if query expansion model is configured
+        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError("No default_query_expansion_model configured for query rewriting")
+
+        # Use the configured model
+        expansion_model = self.vector_stores_config.default_query_expansion_model
+        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
+
+        # Validate that the model is available and is an LLM
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
+
+        model_found = False
+        for model in models_response.data:
+            if model.identifier == chat_model:
+                if model.model_type != ModelType.llm:
+                    raise ValueError(
+                        f"Configured query expansion model '{chat_model}' is not an LLM model "
+                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
+                    )
+                model_found = True
+                break
+
+        if not model_found:
+            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
+            raise ValueError(
+                f"Configured query expansion model '{chat_model}' is not available. "
+                f"Available LLM models: {available_llm_models}"
+            )
+
+        # Use the configured prompt (has a default value)
+        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        if response.choices and len(response.choices) > 0:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM model for query rewriting")
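For reference, this is what the LLM receives when the default template is used, as a worked example of the .format call above. Note also that a rewrite_query request without a configured default_query_expansion_model raises rather than silently falling back to the original query, so misconfiguration surfaces immediately.

# Worked example: rendering the default query-expansion prompt.
template = (
    "Expand this query with relevant synonyms and related terms. "
    "Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
)
print(template.format(query="slow inference on A100"))
# Output:
# Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:
#
# slow inference on A100
#
# Improved query: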