diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py
index a32e1d8a2..8fab715f2 100644
--- a/src/llama_stack/core/datatypes.py
+++ b/src/llama_stack/core/datatypes.py
@@ -366,6 +366,27 @@ class QualifiedModel(BaseModel):
     model_id: str
 
 
+class RewriteQueryParams(BaseModel):
+    """Parameters for query rewriting/expansion."""
+
+    model: QualifiedModel | None = Field(
+        default=None,
+        description="LLM model for query rewriting/expansion in vector search.",
+    )
+    prompt: str = Field(
+        default=DEFAULT_QUERY_EXPANSION_PROMPT,
+        description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
+    )
+    max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens for query expansion responses.",
+    )
+    temperature: float = Field(
+        default=0.3,
+        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+    )
+
+
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""
 
@@ -377,21 +398,9 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
-    default_query_expansion_model: QualifiedModel | None = Field(
+    rewrite_query_params: RewriteQueryParams | None = Field(
         default=None,
-        description="Default LLM model for query expansion/rewriting in vector search.",
-    )
-    query_expansion_prompt: str = Field(
-        default=DEFAULT_QUERY_EXPANSION_PROMPT,
-        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
-    )
-    query_expansion_max_tokens: int = Field(
-        default=100,
-        description="Maximum number of tokens for query expansion responses.",
-    )
-    query_expansion_temperature: float = Field(
-        default=0.3,
-        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+        description="Parameters for query rewriting/expansion. None disables query rewriting.",
     )
 
 
diff --git a/src/llama_stack/core/stack.py b/src/llama_stack/core/stack.py
index dae6e8ec9..0bebf800d 100644
--- a/src/llama_stack/core/stack.py
+++ b/src/llama_stack/core/stack.py
@@ -14,7 +14,7 @@ from typing import Any
 import yaml
 
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
+from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -145,61 +145,67 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
         return
 
     # Validate default embedding model
-    default_embedding_model = vector_stores_config.default_embedding_model
-    if default_embedding_model is not None:
-        provider_id = default_embedding_model.provider_id
-        model_id = default_embedding_model.model_id
-        default_model_id = f"{provider_id}/{model_id}"
+    if vector_stores_config.default_embedding_model is not None:
+        await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)
 
-        if Api.models not in impls:
-            raise ValueError(
-                f"Models API is not available but vector_stores config requires model '{default_model_id}'"
-            )
+    # Validate default rewrite query model
+    if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model:
+        await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
 
-        models_impl = impls[Api.models]
-        response = await models_impl.list_models()
-        models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
-        default_model = models_list.get(default_model_id)
-        if default_model is None:
-            raise ValueError(
-                f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
-            )
+async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that an embedding model exists and has required metadata."""
+    provider_id = embedding_model.provider_id
+    model_id = embedding_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"
 
-        embedding_dimension = default_model.metadata.get("embedding_dimension")
-        if embedding_dimension is None:
-            raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
+    if Api.models not in impls:
+        raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")
 
-        try:
-            int(embedding_dimension)
-        except ValueError as err:
-            raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
 
-        logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+    model = models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
+        )
 
-    # Validate default query expansion model
-    default_query_expansion_model = vector_stores_config.default_query_expansion_model
-    if default_query_expansion_model is not None:
-        provider_id = default_query_expansion_model.provider_id
-        model_id = default_query_expansion_model.model_id
-        query_model_id = f"{provider_id}/{model_id}"
+    embedding_dimension = model.metadata.get("embedding_dimension")
+    if embedding_dimension is None:
+        raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")
 
-        if Api.models not in impls:
-            raise ValueError(
-                f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
-            )
+    try:
+        int(embedding_dimension)
+    except ValueError as err:
+        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
 
-        models_impl = impls[Api.models]
-        response = await models_impl.list_models()
-        llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+    logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")
 
-        query_expansion_model = llm_models_list.get(query_model_id)
-        if query_expansion_model is None:
-            raise ValueError(
-                f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
-            )
-        logger.debug(f"Validated default query expansion model: {query_model_id}")
 
+async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
+    """Validate that a rewrite query model exists and is accessible."""
+    provider_id = rewrite_query_model.provider_id
+    model_id = rewrite_query_model.model_id
+    model_identifier = f"{provider_id}/{model_id}"
+
+    if Api.models not in impls:
+        raise ValueError(
+            f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
+        )
+
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+
+    model = llm_models_list.get(model_identifier)
+    if model is None:
+        raise ValueError(
+            f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
+        )
+
+    logger.debug(f"Validated rewrite query model: {model_identifier}")
 
 
 async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -466,9 +472,9 @@ class Stack:
         await validate_safety_config(self.run_config.safety, impls)
 
         # Set global query expansion configuration from stack config
-        from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+        from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
 
-        set_default_query_expansion_config(self.run_config.vector_stores)
+        set_default_rewrite_query_config(self.run_config.vector_stores)
 
         self.impls = impls
 
diff --git a/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml b/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml
index 219ffdce3..7721138c7 100644
--- a/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml
+++ b/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml
@@ -288,15 +288,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/distributions/ci-tests/run.yaml b/src/llama_stack/distributions/ci-tests/run.yaml
index e352e9268..b791e1488 100644
--- a/src/llama_stack/distributions/ci-tests/run.yaml
+++ b/src/llama_stack/distributions/ci-tests/run.yaml
@@ -279,15 +279,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml
index e81febb0e..9c250c05a 100644
--- a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml
+++ b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml
@@ -291,15 +291,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/distributions/starter-gpu/run.yaml b/src/llama_stack/distributions/starter-gpu/run.yaml
index edae6f66d..65f9ae326 100644
--- a/src/llama_stack/distributions/starter-gpu/run.yaml
+++ b/src/llama_stack/distributions/starter-gpu/run.yaml
@@ -282,15 +282,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml
index 9ed74d96d..3314bb9e9 100644
--- a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml
+++ b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml
@@ -288,15 +288,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/distributions/starter/run.yaml b/src/llama_stack/distributions/starter/run.yaml
index 73679a152..e88539e6a 100644
--- a/src/llama_stack/distributions/starter/run.yaml
+++ b/src/llama_stack/distributions/starter/run.yaml
@@ -279,15 +279,5 @@ vector_stores:
   default_embedding_model:
     provider_id: sentence-transformers
     model_id: nomic-ai/nomic-embed-text-v1.5
-  query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
-    Return only the improved query, no explanations:
-
-
-    {query}
-
-
-    Improved query:'
-  query_expansion_max_tokens: 100
-  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
diff --git a/src/llama_stack/providers/utils/memory/query_expansion_config.py b/src/llama_stack/providers/utils/memory/query_expansion_config.py
deleted file mode 100644
index 0b51c1a9a..000000000
--- a/src/llama_stack/providers/utils/memory/query_expansion_config.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
-
-# Global configuration for query expansion - set during stack startup
-_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None
-_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100
-_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3
-_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None
-
-
-def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None):
-    """Set the global default query expansion configuration from stack config."""
-    global \
-        _DEFAULT_QUERY_EXPANSION_MODEL, \
-        _QUERY_EXPANSION_PROMPT_OVERRIDE, \
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE
-    if vector_stores_config:
-        _DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model
-        # Only set override if user provided a custom prompt different from default
-        if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
-            _QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt
-        else:
-            _QUERY_EXPANSION_PROMPT_OVERRIDE = None
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature
-    else:
-        _DEFAULT_QUERY_EXPANSION_MODEL = None
-        _QUERY_EXPANSION_PROMPT_OVERRIDE = None
-        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100
-        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3
diff --git a/src/llama_stack/providers/utils/memory/rewrite_query_config.py b/src/llama_stack/providers/utils/memory/rewrite_query_config.py
new file mode 100644
index 000000000..9c53638b8
--- /dev/null
+++ b/src/llama_stack/providers/utils/memory/rewrite_query_config.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+# Global configuration for query rewriting - set during stack startup
+_DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
+_DEFAULT_REWRITE_QUERY_MAX_TOKENS: int = 100
+_DEFAULT_REWRITE_QUERY_TEMPERATURE: float = 0.3
+_REWRITE_QUERY_PROMPT_OVERRIDE: str | None = None
+
+
+def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig | None):
+    """Set the global default query rewriting configuration from stack config."""
+    global \
+        _DEFAULT_REWRITE_QUERY_MODEL, \
+        _REWRITE_QUERY_PROMPT_OVERRIDE, \
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS, \
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE
+    if vector_stores_config and vector_stores_config.rewrite_query_params:
+        params = vector_stores_config.rewrite_query_params
+        _DEFAULT_REWRITE_QUERY_MODEL = params.model
+        # Only set override if user provided a custom prompt different from default
+        if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+            _REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
+        else:
+            _REWRITE_QUERY_PROMPT_OVERRIDE = None
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS = params.max_tokens
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE = params.temperature
+    else:
+        _DEFAULT_REWRITE_QUERY_MODEL = None
+        _REWRITE_QUERY_PROMPT_OVERRIDE = None
+        _DEFAULT_REWRITE_QUERY_MAX_TOKENS = 100
+        _DEFAULT_REWRITE_QUERY_TEMPERATURE = 0.3
diff --git a/src/llama_stack/providers/utils/memory/vector_store.py b/src/llama_stack/providers/utils/memory/vector_store.py
index 61fa996e4..11754bae2 100644
--- a/src/llama_stack/providers/utils/memory/vector_store.py
+++ b/src/llama_stack/providers/utils/memory/vector_store.py
@@ -38,7 +38,7 @@ from llama_stack_api import (
 
 log = get_logger(name=__name__, category="providers::utils")
 
-from llama_stack.providers.utils.memory import query_expansion_config
+from llama_stack.providers.utils.memory import rewrite_query_config
 from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
 
 
@@ -295,20 +295,20 @@ class VectorStoreWithIndex:
 
     async def _rewrite_query_for_file_search(self, query: str) -> str:
         """Rewrite a search query using the globally configured LLM model for better retrieval results."""
-        if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
-            log.debug("No default query expansion model configured, using original query")
+        if not rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL:
+            log.debug("No default query rewriting model configured, using original query")
             return query
 
-        model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
+        model_id = f"{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.provider_id}/{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.model_id}"
 
         # Use custom prompt from config if provided, otherwise use built-in default
         # Users only need to configure the model - prompt is automatic with optional override
-        if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
+        if rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE:
             # Custom prompt from config - format if it contains {query} placeholder
             prompt = (
-                query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
-                if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
-                else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
+                rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE.format(query=query)
+                if "{query}" in rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
+                else rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
             )
         else:
             # Use built-in default prompt and format with query
@@ -317,8 +317,8 @@ class VectorStoreWithIndex:
         request = OpenAIChatCompletionRequestWithExtraBody(
             model=model_id,
             messages=[{"role": "user", "content": prompt}],
-            max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
-            temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
+            max_tokens=rewrite_query_config._DEFAULT_REWRITE_QUERY_MAX_TOKENS,
+            temperature=rewrite_query_config._DEFAULT_REWRITE_QUERY_TEMPERATURE,
         )
 
         response = await self.inference_api.openai_chat_completion(request)
diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
index 83bf22f34..07ec41bec 100644
--- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
+++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
@@ -1236,9 +1236,9 @@ async def test_query_expansion_functionality(vector_io_adapter):
     """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock
 
-    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
     from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
-    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+    from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
     from llama_stack_api import QueryChunksResponse
 
@@ -1266,13 +1266,12 @@ async def test_query_expansion_functionality(vector_io_adapter):
 
     # Test 1: Query expansion with default prompt (no custom prompt configured)
     mock_vector_stores_config = MagicMock()
-    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
-    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
-    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
-    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
+    )
 
     # Set global config
-    set_default_query_expansion_config(mock_vector_stores_config)
+    set_default_rewrite_query_config(mock_vector_stores_config)
 
     # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(
@@ -1305,10 +1304,13 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_inference_api.reset_mock()
     mock_index.reset_mock()
 
-    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
-    mock_vector_stores_config.query_expansion_max_tokens = 150
-    mock_vector_stores_config.query_expansion_temperature = 0.7
-    set_default_query_expansion_config(mock_vector_stores_config)
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"),
+        prompt="Custom prompt for rewriting: {query}",
+        max_tokens=150,
+        temperature=0.7,
+    )
+    set_default_rewrite_query_config(mock_vector_stores_config)
 
     result = await vector_store_with_index.query_chunks("test query", params)
 
@@ -1328,7 +1330,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_index.reset_mock()
 
     # Clear global config
-    set_default_query_expansion_config(None)
+    set_default_rewrite_query_config(None)
 
     params = {"rewrite_query": True, "max_chunks": 5}
     result2 = await vector_store_with_index.query_chunks("test query", params)
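
For reference, a minimal sketch of what enabling query rewriting could look like in a distribution run.yaml under the new shape, assuming the YAML keys mirror the RewriteQueryParams field names (model, prompt, max_tokens, temperature). The LLM provider_id/model_id values below are placeholders, prompt is omitted so the built-in default applies, and leaving rewrite_query_params unset keeps query rewriting disabled:

vector_stores:
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
  rewrite_query_params:
    model:
      provider_id: example-llm-provider  # placeholder: any registered LLM provider
      model_id: example-llm-model  # placeholder: any registered LLM model
    max_tokens: 100
    temperature: 0.3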