renaming query_expansion to rewrite_query, consolidating config into RewriteQueryParams, and cleaning up validation

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-11-21 23:38:13 -05:00
parent d887f1f8bb
commit 31e28b6d17
12 changed files with 138 additions and 180 deletions

View file

@@ -366,6 +366,27 @@ class QualifiedModel(BaseModel):
     model_id: str


+class RewriteQueryParams(BaseModel):
+    """Parameters for query rewriting/expansion."""
+
+    model: QualifiedModel | None = Field(
+        default=None,
+        description="LLM model for query rewriting/expansion in vector search.",
+    )
+    prompt: str = Field(
+        default=DEFAULT_QUERY_EXPANSION_PROMPT,
+        description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
+    )
+    max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens for query expansion responses.",
+    )
+    temperature: float = Field(
+        default=0.3,
+        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+    )
+
+
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""
@@ -377,21 +398,9 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
-    default_query_expansion_model: QualifiedModel | None = Field(
+    rewrite_query_params: RewriteQueryParams | None = Field(
         default=None,
-        description="Default LLM model for query expansion/rewriting in vector search.",
-    )
-    query_expansion_prompt: str = Field(
-        default=DEFAULT_QUERY_EXPANSION_PROMPT,
-        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
-    )
-    query_expansion_max_tokens: int = Field(
-        default=100,
-        description="Maximum number of tokens for query expansion responses.",
-    )
-    query_expansion_temperature: float = Field(
-        default=0.3,
-        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+        description="Parameters for query rewriting/expansion. None disables query rewriting.",
     )
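
For illustration, a minimal sketch of constructing the consolidated config (the provider and model ids below are hypothetical placeholders, not values from this change):

from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams, VectorStoresConfig

# One nested object replaces the old flat query_expansion_* fields.
# Leaving rewrite_query_params unset (None) disables query rewriting.
config = VectorStoresConfig(
    rewrite_query_params=RewriteQueryParams(
        model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),  # hypothetical ids
        max_tokens=100,
        temperature=0.3,
    ),
)
print(config.rewrite_query_params.prompt[:40])  # defaults to DEFAULT_QUERY_EXPANSION_PROMPT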

View file

@@ -14,7 +14,7 @@ from typing import Any

 import yaml

 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
+from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -145,61 +145,67 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
         return

     # Validate default embedding model
-    default_embedding_model = vector_stores_config.default_embedding_model
-    if default_embedding_model is not None:
-        provider_id = default_embedding_model.provider_id
-        model_id = default_embedding_model.model_id
-        default_model_id = f"{provider_id}/{model_id}"
+    if vector_stores_config.default_embedding_model is not None:
+        await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)

-        if Api.models not in impls:
-            raise ValueError(
-                f"Models API is not available but vector_stores config requires model '{default_model_id}'"
-            )
+    # Validate default rewrite query model
+    if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model:
+        await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)

-        models_impl = impls[Api.models]
-        response = await models_impl.list_models()
-        models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
-
-        default_model = models_list.get(default_model_id)
-        if default_model is None:
-            raise ValueError(
-                f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
-            )
+async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that an embedding model exists and has required metadata."""
+    provider_id = embedding_model.provider_id
+    model_id = embedding_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"

-        embedding_dimension = default_model.metadata.get("embedding_dimension")
-        if embedding_dimension is None:
-            raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
+    if Api.models not in impls:
+        raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")

-        try:
-            int(embedding_dimension)
-        except ValueError as err:
-            raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}

-        logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+    model = models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
+        )

-    # Validate default query expansion model
-    default_query_expansion_model = vector_stores_config.default_query_expansion_model
-    if default_query_expansion_model is not None:
-        provider_id = default_query_expansion_model.provider_id
-        model_id = default_query_expansion_model.model_id
-        query_model_id = f"{provider_id}/{model_id}"
+    embedding_dimension = model.metadata.get("embedding_dimension")
+    if embedding_dimension is None:
+        raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")

-        if Api.models not in impls:
-            raise ValueError(
-                f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
-            )
+    try:
+        int(embedding_dimension)
+    except ValueError as err:
+        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

-        models_impl = impls[Api.models]
-        response = await models_impl.list_models()
-        llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+    logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")

-        query_expansion_model = llm_models_list.get(query_model_id)
-        if query_expansion_model is None:
-            raise ValueError(
-                f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
-            )
-
-        logger.debug(f"Validated default query expansion model: {query_model_id}")
+async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that a rewrite query model exists and is accessible."""
+    provider_id = rewrite_query_model.provider_id
+    model_id = rewrite_query_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"
+
+    if Api.models not in impls:
+        raise ValueError(
+            f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
+        )
+
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+
+    model = llm_models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
+        )
+
+    logger.debug(f"Validated rewrite query model: {model_identifier}")


 async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
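
Both helpers share the same fail-fast shape: check the Models API, list models of the expected type, and raise with the available identifiers on a miss. A self-contained mimic of that pattern (standalone sketch, not the stack's own code):

import asyncio

async def check_model(models_by_id: dict[str, object], model_identifier: str) -> None:
    # Same lookup-or-raise shape as _validate_rewrite_query_model above.
    if model_identifier not in models_by_id:
        raise ValueError(
            f"Rewrite query model '{model_identifier}' not found. "
            f"Available LLM models: {list(models_by_id.keys())}"
        )

try:
    asyncio.run(check_model({"ollama/llama3.2:3b": object()}, "ollama/missing-model"))
except ValueError as err:
    print(err)  # names every known identifier, which makes config typos easy to spot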
@@ -466,9 +472,9 @@ class Stack:
         await validate_safety_config(self.run_config.safety, impls)

         # Set global query expansion configuration from stack config
-        from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+        from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config

-        set_default_query_expansion_config(self.run_config.vector_stores)
+        set_default_rewrite_query_config(self.run_config.vector_stores)

         self.impls = impls

View file

@@ -288,15 +288,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
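
The run.yaml changes in this commit all drop the flat keys in favor of the nested object; a sketch of the replacement fragment, assuming the YAML keys mirror the Pydantic field names (rewrite_query_params, model, max_tokens, temperature) and using placeholder model ids:

import yaml

# Hypothetical run.yaml fragment under the new schema; omitting
# rewrite_query_params disables query rewriting altogether.
fragment = yaml.safe_load("""
vector_stores:
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
  rewrite_query_params:
    model:
      provider_id: ollama
      model_id: llama3.2:3b
    max_tokens: 100
    temperature: 0.3
""")
print(fragment["vector_stores"]["rewrite_query_params"]["max_tokens"])  # 100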

View file

@@ -279,15 +279,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -291,15 +291,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -282,15 +282,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -288,15 +288,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -279,15 +279,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-    {query}
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -1,37 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
-
-# Global configuration for query expansion - set during stack startup
-_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None
-_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100
-_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3
-_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None
-
-
-def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None):
-    """Set the global default query expansion configuration from stack config."""
-    global \
-        _DEFAULT_QUERY_EXPANSION_MODEL, \
-        _QUERY_EXPANSION_PROMPT_OVERRIDE, \
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE
-    if vector_stores_config:
-        _DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model
-        # Only set override if user provided a custom prompt different from default
-        if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
-            _QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt
-        else:
-            _QUERY_EXPANSION_PROMPT_OVERRIDE = None
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature
-    else:
-        _DEFAULT_QUERY_EXPANSION_MODEL = None
-        _QUERY_EXPANSION_PROMPT_OVERRIDE = None
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3

View file

@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+# Global configuration for query rewriting - set during stack startup
+_DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
+_DEFAULT_REWRITE_QUERY_MAX_TOKENS: int = 100
+_DEFAULT_REWRITE_QUERY_TEMPERATURE: float = 0.3
+_REWRITE_QUERY_PROMPT_OVERRIDE: str | None = None
+
+
+def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig | None):
+    """Set the global default query rewriting configuration from stack config."""
+    global \
+        _DEFAULT_REWRITE_QUERY_MODEL, \
+        _REWRITE_QUERY_PROMPT_OVERRIDE, \
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS, \
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE
+    if vector_stores_config and vector_stores_config.rewrite_query_params:
+        params = vector_stores_config.rewrite_query_params
+        _DEFAULT_REWRITE_QUERY_MODEL = params.model
+        # Only set override if user provided a custom prompt different from default
+        if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+            _REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
+        else:
+            _REWRITE_QUERY_PROMPT_OVERRIDE = None
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS = params.max_tokens
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE = params.temperature
+    else:
+        _DEFAULT_REWRITE_QUERY_MODEL = None
+        _REWRITE_QUERY_PROMPT_OVERRIDE = None
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS = 100
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE = 0.3
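
Downstream code reads these module attributes directly, so resetting with None restores the built-in defaults. A minimal sketch of the reset path, assuming the module above is importable:

from llama_stack.providers.utils.memory import rewrite_query_config
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config

set_default_rewrite_query_config(None)  # e.g., run config has no rewrite_query_params
assert rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL is None
assert rewrite_query_config._DEFAULT_REWRITE_QUERY_MAX_TOKENS == 100
assert rewrite_query_config._DEFAULT_REWRITE_QUERY_TEMPERATURE == 0.3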

View file

@@ -38,7 +38,7 @@ from llama_stack_api import (

 log = get_logger(name=__name__, category="providers::utils")

-from llama_stack.providers.utils.memory import query_expansion_config
+from llama_stack.providers.utils.memory import rewrite_query_config
 from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
@@ -295,20 +295,20 @@ class VectorStoreWithIndex:

     async def _rewrite_query_for_file_search(self, query: str) -> str:
         """Rewrite a search query using the globally configured LLM model for better retrieval results."""
-        if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
-            log.debug("No default query expansion model configured, using original query")
+        if not rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL:
+            log.debug("No default query rewriting model configured, using original query")
             return query

-        model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
+        model_id = f"{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.provider_id}/{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.model_id}"

         # Use custom prompt from config if provided, otherwise use built-in default
         # Users only need to configure the model - prompt is automatic with optional override
-        if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
+        if rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE:
             # Custom prompt from config - format if it contains {query} placeholder
             prompt = (
-                query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
-                if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
-                else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
+                rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE.format(query=query)
+                if "{query}" in rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
+                else rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
             )
         else:
             # Use built-in default prompt and format with query
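
The override handling above boils down to one rule: a custom prompt is formatted only if it contains the {query} placeholder, otherwise it is sent verbatim; without an override, the built-in template is always formatted. A standalone restatement of that rule (hypothetical helper, not the method itself):

def select_prompt(override: str | None, default_template: str, query: str) -> str:
    # Custom prompt: format only when it has a {query} slot; else use as-is.
    if override:
        return override.format(query=query) if "{query}" in override else override
    # Built-in default: always formatted with the query.
    return default_template.format(query=query)

print(select_prompt(None, "Expand this query: {query}", "cheap flights"))
print(select_prompt("Rewrite: {query}", "Expand this query: {query}", "cheap flights"))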
@@ -317,8 +317,8 @@ class VectorStoreWithIndex:
         request = OpenAIChatCompletionRequestWithExtraBody(
             model=model_id,
             messages=[{"role": "user", "content": prompt}],
-            max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
-            temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
+            max_tokens=rewrite_query_config._DEFAULT_REWRITE_QUERY_MAX_TOKENS,
+            temperature=rewrite_query_config._DEFAULT_REWRITE_QUERY_TEMPERATURE,
         )

         response = await self.inference_api.openai_chat_completion(request)

View file

@@ -1236,9 +1236,9 @@ async def test_query_expansion_functionality(vector_io_adapter):
     """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock

-    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
     from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
-    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+    from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
     from llama_stack_api import QueryChunksResponse
@@ -1266,13 +1266,12 @@ async def test_query_expansion_functionality(vector_io_adapter):
     # Test 1: Query expansion with default prompt (no custom prompt configured)
     mock_vector_stores_config = MagicMock()
-    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
-    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
-    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
-    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
+    )

     # Set global config
-    set_default_query_expansion_config(mock_vector_stores_config)
+    set_default_rewrite_query_config(mock_vector_stores_config)

     # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(
@@ -1305,10 +1304,13 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_inference_api.reset_mock()
     mock_index.reset_mock()

-    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
-    mock_vector_stores_config.query_expansion_max_tokens = 150
-    mock_vector_stores_config.query_expansion_temperature = 0.7
-    set_default_query_expansion_config(mock_vector_stores_config)
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"),
+        prompt="Custom prompt for rewriting: {query}",
+        max_tokens=150,
+        temperature=0.7,
+    )
+    set_default_rewrite_query_config(mock_vector_stores_config)

     result = await vector_store_with_index.query_chunks("test query", params)
@@ -1328,7 +1330,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_index.reset_mock()

     # Clear global config
-    set_default_query_expansion_config(None)
+    set_default_rewrite_query_config(None)

     params = {"rewrite_query": True, "max_chunks": 5}
     result2 = await vector_store_with_index.query_chunks("test query", params)
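
The cleared-config case above exercises the early-return guard in _rewrite_query_for_file_search: with no model configured, rewrite_query=True quietly falls back to the original query instead of erroring. A tiny standalone mimic of that guard (hypothetical names, not the stack's code):

_MODEL: str | None = None  # stands in for rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL

def rewrite_or_passthrough(query: str) -> str:
    # No configured model: return the query unchanged rather than failing.
    if not _MODEL:
        return query
    return f"[rewritten via {_MODEL}] {query}"

assert rewrite_or_passthrough("test query") == "test query"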