From d887f1f8bb8d31d3af599fb510edc0b1c6befe7d Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Fri, 21 Nov 2025 11:27:25 -0500
Subject: [PATCH] refactor to configure query expansion model only at build time

Signed-off-by: Francisco Javier Arceo
---
 src/llama_stack/core/datatypes.py             | 11 +-
 src/llama_stack/core/resolver.py              | 12 --
 src/llama_stack/core/routers/vector_io.py     | 13 --
 src/llama_stack/core/stack.py                 | 75 +++++---
 .../ci-tests/run-with-postgres-store.yaml     |  2 +
 .../distributions/ci-tests/run.yaml           |  2 +
 .../starter-gpu/run-with-postgres-store.yaml  |  2 +
 .../distributions/starter-gpu/run.yaml        |  2 +
 .../starter/run-with-postgres-store.yaml      |  2 +
 .../distributions/starter/run.yaml            |  2 +
 .../inline/vector_io/faiss/__init__.py        |  7 +-
 .../providers/inline/vector_io/faiss/faiss.py |  6 -
 .../inline/vector_io/sqlite_vec/__init__.py   |  7 +-
 .../inline/vector_io/sqlite_vec/sqlite_vec.py | 12 +-
 .../remote/vector_io/chroma/__init__.py       |  7 +-
 .../remote/vector_io/chroma/chroma.py         |  9 +-
 .../remote/vector_io/milvus/__init__.py       |  7 +-
 .../remote/vector_io/milvus/milvus.py         |  6 -
 .../remote/vector_io/pgvector/__init__.py     |  4 +-
 .../remote/vector_io/pgvector/pgvector.py     |  9 +-
 .../remote/vector_io/qdrant/__init__.py       |  7 +-
 .../remote/vector_io/qdrant/qdrant.py         |  6 -
 .../remote/vector_io/weaviate/__init__.py     |  4 +-
 .../remote/vector_io/weaviate/weaviate.py     |  6 -
 .../providers/utils/memory/__init__.py        |  4 +
 .../providers/utils/memory/constants.py       |  8 +
 .../utils/memory/openai_vector_store_mixin.py |  9 -
 .../utils/memory/query_expansion_config.py    | 37 ++++
 .../providers/utils/memory/vector_store.py    | 146 ++++-----------
 src/llama_stack_api/vector_stores.py          |  1 -
 .../test_vector_io_openai_vector_stores.py    | 170 ++++++++++--------
 31 files changed, 280 insertions(+), 315 deletions(-)
 create mode 100644 src/llama_stack/providers/utils/memory/constants.py
 create mode 100644 src/llama_stack/providers/utils/memory/query_expansion_config.py

diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py
index 49747d477..a32e1d8a2 100644
--- a/src/llama_stack/core/datatypes.py
+++ b/src/llama_stack/core/datatypes.py
@@ -18,6 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,
@@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
         description="Default LLM model for query expansion/rewriting in vector search.",
     )
     query_expansion_prompt: str = Field(
-        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        default=DEFAULT_QUERY_EXPANSION_PROMPT,
         description="Prompt template for query expansion. 
Use {query} as placeholder for the original query.", ) + query_expansion_max_tokens: int = Field( + default=100, + description="Maximum number of tokens for query expansion responses.", + ) + query_expansion_temperature: float = Field( + default=0.3, + description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).", + ) class SafetyConfig(BaseModel): diff --git a/src/llama_stack/core/resolver.py b/src/llama_stack/core/resolver.py index ebdbb0b18..6bc32c2d0 100644 --- a/src/llama_stack/core/resolver.py +++ b/src/llama_stack/core/resolver.py @@ -374,13 +374,6 @@ async def instantiate_provider( method = "get_adapter_impl" args = [config, deps] - # Add vector_stores_config for vector_io providers - if ( - "vector_stores_config" in inspect.signature(getattr(module, method)).parameters - and provider_spec.api == Api.vector_io - ): - args.append(run_config.vector_stores) - elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" @@ -401,11 +394,6 @@ async def instantiate_provider( args.append(policy) if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry: args.append(run_config.telemetry.enabled) - if ( - "vector_stores_config" in inspect.signature(getattr(module, method)).parameters - and provider_spec.api == Api.vector_io - ): - args.append(run_config.vector_stores) fn = getattr(module, method) impl = await fn(*args) diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index a865a3793..5256dda44 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -99,19 +99,6 @@ class VectorIORouter(VectorIO): ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}") provider = await self.routing_table.get_provider_impl(vector_store_id) - - # Ensure params dict exists and add vector_stores_config for query rewriting - if params is None: - params = {} - - logger.debug(f"Router vector_stores_config: {self.vector_stores_config}") - if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"): - logger.debug( - f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}" - ) - - params["vector_stores_config"] = self.vector_stores_config - return await provider.query_chunks(vector_store_id, query, params) # OpenAI Vector Stores API endpoints diff --git a/src/llama_stack/core/stack.py b/src/llama_stack/core/stack.py index 8ba1f2afd..dae6e8ec9 100644 --- a/src/llama_stack/core/stack.py +++ b/src/llama_stack/core/stack.py @@ -144,35 +144,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig if vector_stores_config is None: return + # Validate default embedding model default_embedding_model = vector_stores_config.default_embedding_model - if default_embedding_model is None: - return + if default_embedding_model is not None: + provider_id = default_embedding_model.provider_id + model_id = default_embedding_model.model_id + default_model_id = f"{provider_id}/{model_id}" - provider_id = default_embedding_model.provider_id - model_id = default_embedding_model.model_id - default_model_id = f"{provider_id}/{model_id}" + if Api.models not in impls: + raise ValueError( + f"Models API is not available but vector_stores config requires model '{default_model_id}'" + ) - if Api.models not in impls: - raise ValueError(f"Models API is not available but vector_stores config requires model 
'{default_model_id}'") + models_impl = impls[Api.models] + response = await models_impl.list_models() + models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"} - models_impl = impls[Api.models] - response = await models_impl.list_models() - models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"} + default_model = models_list.get(default_model_id) + if default_model is None: + raise ValueError( + f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}" + ) - default_model = models_list.get(default_model_id) - if default_model is None: - raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}") + embedding_dimension = default_model.metadata.get("embedding_dimension") + if embedding_dimension is None: + raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata") - embedding_dimension = default_model.metadata.get("embedding_dimension") - if embedding_dimension is None: - raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata") + try: + int(embedding_dimension) + except ValueError as err: + raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err - try: - int(embedding_dimension) - except ValueError as err: - raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err + logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})") - logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})") + # Validate default query expansion model + default_query_expansion_model = vector_stores_config.default_query_expansion_model + if default_query_expansion_model is not None: + provider_id = default_query_expansion_model.provider_id + model_id = default_query_expansion_model.model_id + query_model_id = f"{provider_id}/{model_id}" + + if Api.models not in impls: + raise ValueError( + f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'" + ) + + models_impl = impls[Api.models] + response = await models_impl.list_models() + llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"} + + query_expansion_model = llm_models_list.get(query_model_id) + if query_expansion_model is None: + raise ValueError( + f"Query expansion model '{query_model_id}' not found. 
Available LLM models: {list(llm_models_list.keys())}" + ) + + logger.debug(f"Validated default query expansion model: {query_model_id}") async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]): @@ -437,6 +464,12 @@ class Stack: await refresh_registry_once(impls) await validate_vector_stores_config(self.run_config.vector_stores, impls) await validate_safety_config(self.run_config.safety, impls) + + # Set global query expansion configuration from stack config + from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config + + set_default_query_expansion_config(self.run_config.vector_stores) + self.impls = impls def create_registry_refresh_task(self): diff --git a/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml b/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml index 8110dbdf6..219ffdce3 100644 --- a/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml +++ b/src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml @@ -296,5 +296,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/distributions/ci-tests/run.yaml b/src/llama_stack/distributions/ci-tests/run.yaml index 809b0ef1c..e352e9268 100644 --- a/src/llama_stack/distributions/ci-tests/run.yaml +++ b/src/llama_stack/distributions/ci-tests/run.yaml @@ -287,5 +287,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml index ca47d7f4c..e81febb0e 100644 --- a/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +++ b/src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml @@ -299,5 +299,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/distributions/starter-gpu/run.yaml b/src/llama_stack/distributions/starter-gpu/run.yaml index 15555c262..edae6f66d 100644 --- a/src/llama_stack/distributions/starter-gpu/run.yaml +++ b/src/llama_stack/distributions/starter-gpu/run.yaml @@ -290,5 +290,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml index 423b30452..9ed74d96d 100644 --- a/src/llama_stack/distributions/starter/run-with-postgres-store.yaml +++ b/src/llama_stack/distributions/starter/run-with-postgres-store.yaml @@ -296,5 +296,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/distributions/starter/run.yaml b/src/llama_stack/distributions/starter/run.yaml index a0f56fc42..73679a152 100644 --- a/src/llama_stack/distributions/starter/run.yaml +++ b/src/llama_stack/distributions/starter/run.yaml @@ -287,5 +287,7 @@ vector_stores: Improved query:' + query_expansion_max_tokens: 100 + query_expansion_temperature: 0.3 safety: default_shield_id: llama-guard diff --git a/src/llama_stack/providers/inline/vector_io/faiss/__init__.py 
b/src/llama_stack/providers/inline/vector_io/faiss/__init__.py index 1b9dcda76..b834589e3 100644 --- a/src/llama_stack/providers/inline/vector_io/faiss/__init__.py +++ b/src/llama_stack/providers/inline/vector_io/faiss/__init__.py @@ -6,19 +6,16 @@ from typing import Any -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api from .config import FaissVectorIOConfig -async def get_provider_impl( - config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None -): +async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]): from .faiss import FaissVectorIOAdapter assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}" - impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py index ec8afd388..e2aab1a25 100644 --- a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -14,7 +14,6 @@ import faiss # type: ignore[import-untyped] import numpy as np from numpy.typing import NDArray -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin @@ -190,12 +189,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.cache: dict[str, VectorStoreWithIndex] = {} async def initialize(self) -> None: @@ -211,7 +208,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco vector_store, await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), self.inference_api, - self.vector_stores_config, ) self.cache[vector_store.identifier] = index @@ -250,7 +246,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco vector_store=vector_store, index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) async def list_vector_stores(self) -> list[VectorStore]: @@ -284,7 +279,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco vector_store=vector_store, index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py index 53e2ad135..e84c299dc 100644 --- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py @@ -6,18 +6,15 @@ from typing import Any -from 
llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api from .config import SQLiteVectorIOConfig -async def get_provider_impl( - config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None -): +async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]): from .sqlite_vec import SQLiteVecVectorIOAdapter assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}" - impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index b38ce205e..bc6226c84 100644 --- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -14,7 +14,6 @@ import numpy as np import sqlite_vec # type: ignore[import-untyped] from numpy.typing import NDArray -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin @@ -391,12 +390,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro config, inference_api: Inference, files_api: Files | None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.cache: dict[str, VectorStoreWithIndex] = {} self.vector_store_table = None @@ -411,9 +408,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro index = await SQLiteVecIndex.create( vector_store.embedding_dimension, self.config.db_path, vector_store.identifier ) - self.cache[vector_store.identifier] = VectorStoreWithIndex( - vector_store, index, self.inference_api, self.vector_stores_config - ) + self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() @@ -437,9 +432,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro index = await SQLiteVecIndex.create( vector_store.embedding_dimension, self.config.db_path, vector_store.identifier ) - self.cache[vector_store.identifier] = VectorStoreWithIndex( - vector_store, index, self.inference_api, self.vector_stores_config - ) + self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None: if vector_store_id in self.cache: @@ -464,7 +457,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro kvstore=self.kvstore, ), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/remote/vector_io/chroma/__init__.py b/src/llama_stack/providers/remote/vector_io/chroma/__init__.py index 3bce41c36..d774ea643 100644 --- a/src/llama_stack/providers/remote/vector_io/chroma/__init__.py +++ 
b/src/llama_stack/providers/remote/vector_io/chroma/__init__.py @@ -4,17 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api, ProviderSpec from .config import ChromaVectorIOConfig -async def get_adapter_impl( - config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None -): +async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]): from .chroma import ChromaVectorIOAdapter - impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py index d214dff3a..491db6d4d 100644 --- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -11,7 +11,6 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig @@ -126,13 +125,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig, inference_api: Inference, files_api: Files | None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") self.config = config self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.client = None self.cache = {} self.vector_store_table = None @@ -165,7 +162,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc ) ) self.cache[vector_store.identifier] = VectorStoreWithIndex( - vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config + vector_store, ChromaIndex(self.client, collection), self.inference_api ) async def unregister_vector_store(self, vector_store_id: str) -> None: @@ -210,9 +207,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc collection = await maybe_await(self.client.get_collection(vector_store_id)) if not collection: raise ValueError(f"Vector DB {vector_store_id} not found in Chroma") - index = VectorStoreWithIndex( - vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config - ) + index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/remote/vector_io/milvus/__init__.py b/src/llama_stack/providers/remote/vector_io/milvus/__init__.py index b73cf9b3e..1b703d486 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/__init__.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/__init__.py @@ -4,18 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
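[Note: across the chroma hunks above (and the faiss/sqlite-vec ones before them), the per-store wrapper loses its trailing config argument. A minimal stand-alone sketch of the resulting shape follows; the stub dataclass and placeholder values are hypothetical, only the three-argument pattern comes from the diff.]

    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class VectorStoreWithIndexSketch:
        """Mirrors the post-refactor wrapper: no vector_stores_config field."""

        vector_store: Any   # registered VectorStore record
        index: Any          # backend-specific EmbeddingIndex (ChromaIndex, FaissIndex, ...)
        inference_api: Any  # used for embeddings and, now, for globally configured rewriting


    # Every construction site in these hunks collapses to the same three arguments:
    wrapper = VectorStoreWithIndexSketch(vector_store=object(), index=object(), inference_api=object())
    print(wrapper)
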
-from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api, ProviderSpec from .config import MilvusVectorIOConfig -async def get_adapter_impl( - config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None -): +async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]): from .milvus import MilvusVectorIOAdapter assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}" - impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index 3b21f3278..044d678fa 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -11,7 +11,6 @@ from typing import Any from numpy.typing import NDArray from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig @@ -273,14 +272,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig, inference_api: Inference, files_api: Files | None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.cache = {} self.client = None self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.vector_store_table = None self.metadata_collection_name = "openai_vector_stores_metadata" @@ -301,7 +298,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc kvstore=self.kvstore, ), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store.identifier] = index if isinstance(self.config, RemoteMilvusVectorIOConfig): @@ -329,7 +325,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store=vector_store, index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store.identifier] = index @@ -352,7 +347,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store=vector_store, index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py index 002caf4b6..ea0139815 100644 --- a/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py +++ b/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api, ProviderSpec from .config import PGVectorVectorIOConfig @@ -13,10 +12,9 @@ from .config import PGVectorVectorIOConfig async def get_adapter_impl( config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec], - vector_stores_config: VectorStoresConfig | None = None, ): from .pgvector import PGVectorVectorIOAdapter - impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 45a38e52a..fe1b8ce35 100644 --- a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -13,7 +13,6 @@ from psycopg2 import sql from psycopg2.extras import Json, execute_values from pydantic import BaseModel, TypeAdapter -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str @@ -335,12 +334,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.conn = None self.cache = {} self.vector_store_table = None @@ -396,7 +393,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt vector_store, index=pgvector_index, inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store.identifier] = index @@ -428,7 +424,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt vector_store, index=pgvector_index, inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store.identifier] = index @@ -469,9 +464,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt vector_store = VectorStore.model_validate_json(vector_store_data) index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn) await index.initialize() - self.cache[vector_store_id] = VectorStoreWithIndex( - vector_store, index, self.inference_api, self.vector_stores_config - ) + self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api) return self.cache[vector_store_id] async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: diff --git a/src/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/src/llama_stack/providers/remote/vector_io/qdrant/__init__.py index 76e167b75..b5b02fe59 100644 --- a/src/llama_stack/providers/remote/vector_io/qdrant/__init__.py +++ b/src/llama_stack/providers/remote/vector_io/qdrant/__init__.py @@ -4,17 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
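[Note: the pgvector entry point above, and the qdrant one below, now match the faiss/sqlite-vec/chroma/milvus signatures exactly, which is what lets resolver.py drop its signature inspection. A runnable sketch of the uniform two-argument shape; StubAdapter and the literal config/deps values are illustrative, only the (config, deps) signature is from the diff.]

    import asyncio
    from typing import Any


    class StubAdapter:
        def __init__(self, config: Any, inference: Any, files: Any | None) -> None:
            self.config, self.inference, self.files = config, inference, files

        async def initialize(self) -> None:
            pass


    async def get_adapter_impl(config: Any, deps: dict[str, Any]) -> StubAdapter:
        # No vector_stores_config parameter: the resolver no longer inspects
        # signatures or appends run_config.vector_stores for vector_io providers.
        impl = StubAdapter(config, deps["inference"], deps.get("files"))
        await impl.initialize()
        return impl


    asyncio.run(get_adapter_impl({"url": "qdrant://..."}, {"inference": object()}))
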
-from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api, ProviderSpec from .config import QdrantVectorIOConfig -async def get_adapter_impl( - config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None -): +async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): from .qdrant import QdrantVectorIOAdapter - impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 2de71f7cc..dc6546646 100644 --- a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -13,7 +13,6 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig @@ -153,14 +152,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Inference, files_api: Files | None = None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.client: AsyncQdrantClient = None self.cache = {} self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.vector_store_table = None self._qdrant_lock = asyncio.Lock() @@ -179,7 +176,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api, - self.vector_stores_config, ) self.cache[vector_store.identifier] = index self.openai_vector_stores = await self._load_openai_vector_stores() @@ -199,7 +195,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store=vector_store, index=QdrantIndex(self.client, vector_store.identifier), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store.identifier] = index @@ -231,7 +226,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store=vector_store, index=QdrantIndex(client=self.client, collection_name=vector_store.identifier), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py index 77bf357f4..a13cca8a1 100644 --- a/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py +++ b/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
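[Note: one recap before the weaviate hunks. The startup validation added in stack.py earlier composes fully qualified model ids the same way for both the embedding and the query expansion defaults. A tiny sketch of that contract; the provider/model names and the stand-in registry dict are made up.]

    # Qualified id = "<provider_id>/<model_id>", matched against Models API
    # identifiers filtered by model_type ("embedding" vs. "llm").
    provider_id, model_id = "ollama", "llama3.2:3b"
    qualified = f"{provider_id}/{model_id}"

    llm_models = {"ollama/llama3.2:3b": object()}  # stand-in for the models registry
    if qualified not in llm_models:
        raise ValueError(f"Query expansion model '{qualified}' not found. Available LLM models: {list(llm_models)}")
    print(f"Validated default query expansion model: {qualified}")
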
-from llama_stack.core.datatypes import VectorStoresConfig from llama_stack_api import Api, ProviderSpec from .config import WeaviateVectorIOConfig @@ -13,10 +12,9 @@ from .config import WeaviateVectorIOConfig async def get_adapter_impl( config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec], - vector_stores_config: VectorStoresConfig | None = None, ): from .weaviate import WeaviateVectorIOAdapter - impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config) + impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files)) await impl.initialize() return impl diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 1c52fa84c..67ec523d7 100644 --- a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -12,7 +12,6 @@ from numpy.typing import NDArray from weaviate.classes.init import Auth from weaviate.classes.query import Filter, HybridFusion -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.storage.kvstore import kvstore_impl from llama_stack.log import get_logger @@ -268,12 +267,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None, - vector_stores_config: VectorStoresConfig | None = None, ) -> None: super().__init__(files_api=files_api, kvstore=None) self.config = config self.inference_api = inference_api - self.vector_stores_config = vector_stores_config self.client_cache = {} self.cache = {} self.vector_store_table = None @@ -319,7 +316,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv vector_store=vector_store, index=idx, inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) # Load OpenAI vector stores metadata into cache @@ -348,7 +344,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api, - self.vector_stores_config, ) async def unregister_vector_store(self, vector_store_id: str) -> None: @@ -383,7 +378,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv vector_store=vector_store, index=WeaviateIndex(client=client, collection_name=vector_store.identifier), inference_api=self.inference_api, - vector_stores_config=self.vector_stores_config, ) self.cache[vector_store_id] = index return index diff --git a/src/llama_stack/providers/utils/memory/__init__.py b/src/llama_stack/providers/utils/memory/__init__.py index 756f351d8..5e0942402 100644 --- a/src/llama_stack/providers/utils/memory/__init__.py +++ b/src/llama_stack/providers/utils/memory/__init__.py @@ -3,3 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from .constants import DEFAULT_QUERY_EXPANSION_PROMPT + +__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"] diff --git a/src/llama_stack/providers/utils/memory/constants.py b/src/llama_stack/providers/utils/memory/constants.py new file mode 100644 index 000000000..d8703bbce --- /dev/null +++ b/src/llama_stack/providers/utils/memory/constants.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Default prompt template for query expansion in vector search +DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:" diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 4e67cf24b..e0293507d 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -379,11 +379,6 @@ class OpenAIVectorStoreMixin(ABC): f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}" ) - # Extract query expansion model from extra_body if provided - query_expansion_model = extra_body.get("query_expansion_model") - if query_expansion_model: - logger.debug(f"Using per-store query expansion model: {query_expansion_model}") - # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None) # Derive the canonical vector_store_id (allow override, else generate) @@ -407,7 +402,6 @@ class OpenAIVectorStoreMixin(ABC): provider_id=provider_id, provider_resource_id=vector_store_id, vector_store_name=params.name, - query_expansion_model=query_expansion_model, ) await self.register_vector_store(vector_store) @@ -621,9 +615,6 @@ class OpenAIVectorStoreMixin(ABC): "rewrite_query": rewrite_query, } - # Add vector_stores_config if available (for query rewriting) - if hasattr(self, "vector_stores_config"): - params["vector_stores_config"] = self.vector_stores_config # TODO: Add support for ranking_options.ranker response = await self.query_chunks( diff --git a/src/llama_stack/providers/utils/memory/query_expansion_config.py b/src/llama_stack/providers/utils/memory/query_expansion_config.py new file mode 100644 index 000000000..0b51c1a9a --- /dev/null +++ b/src/llama_stack/providers/utils/memory/query_expansion_config.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
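[Note: constants.py above is now the single source of truth for the template; both the pydantic default in datatypes.py and the run-time fallback format it the same way. A quick illustration of the {query} placeholder contract; the template text is copied from constants.py, the sample query is made up.]

    DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"

    print(DEFAULT_QUERY_EXPANSION_PROMPT.format(query="small dogs that bark a lot"))
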
+ +from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig +from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT + +# Global configuration for query expansion - set during stack startup +_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None +_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100 +_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3 +_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None + + +def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None): + """Set the global default query expansion configuration from stack config.""" + global \ + _DEFAULT_QUERY_EXPANSION_MODEL, \ + _QUERY_EXPANSION_PROMPT_OVERRIDE, \ + _DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \ + _DEFAULT_QUERY_EXPANSION_TEMPERATURE + if vector_stores_config: + _DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model + # Only set override if user provided a custom prompt different from default + if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT: + _QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt + else: + _QUERY_EXPANSION_PROMPT_OVERRIDE = None + _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens + _DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature + else: + _DEFAULT_QUERY_EXPANSION_MODEL = None + _QUERY_EXPANSION_PROMPT_OVERRIDE = None + _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100 + _DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3 diff --git a/src/llama_stack/providers/utils/memory/vector_store.py b/src/llama_stack/providers/utils/memory/vector_store.py index 71d61787a..61fa996e4 100644 --- a/src/llama_stack/providers/utils/memory/vector_store.py +++ b/src/llama_stack/providers/utils/memory/vector_store.py @@ -17,7 +17,6 @@ import numpy as np from numpy.typing import NDArray from pydantic import BaseModel -from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig from llama_stack.log import get_logger from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -30,19 +29,18 @@ from llama_stack_api import ( Chunk, ChunkMetadata, InterleavedContent, + OpenAIChatCompletionRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody, QueryChunksResponse, RAGDocument, VectorStore, ) -from llama_stack_api.inference import ( - OpenAIChatCompletionRequestWithExtraBody, - OpenAIUserMessageParam, -) -from llama_stack_api.models import ModelType log = get_logger(name=__name__, category="providers::utils") +from llama_stack.providers.utils.memory import query_expansion_config +from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT + class ChunkForDeletion(BaseModel): """Information needed to delete a chunk from a vector store. 
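[Note: for context before the remaining vector_store.py hunks, the intended lifecycle of the new query_expansion_config module is set-once/read-many: Stack.initialize() calls the setter after validation, and VectorStoreWithIndex reads the module-level defaults at query time. A sketch of that wiring, assuming the imports this patch introduces and that the remaining VectorStoresConfig fields keep their defaults; the model identifiers are illustrative.]

    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
    from llama_stack.providers.utils.memory import query_expansion_config

    cfg = VectorStoresConfig(
        default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
        query_expansion_max_tokens=100,
        query_expansion_temperature=0.3,
    )
    # What Stack.initialize() now does once at startup:
    query_expansion_config.set_default_query_expansion_config(cfg)
    print(query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL)  # ollama/llama3.2:3b

    # Passing None restores the built-in defaults (no model, 100 tokens, 0.3):
    query_expansion_config.set_default_query_expansion_config(None)
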
@@ -268,7 +266,6 @@ class VectorStoreWithIndex: vector_store: VectorStore index: EmbeddingIndex inference_api: Api.inference - vector_stores_config: VectorStoresConfig | None = None async def insert_chunks( self, @@ -296,6 +293,39 @@ class VectorStoreWithIndex: embeddings = np.array([c.embedding for c in chunks], dtype=np.float32) await self.index.add_chunks(chunks, embeddings) + async def _rewrite_query_for_file_search(self, query: str) -> str: + """Rewrite a search query using the globally configured LLM model for better retrieval results.""" + if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL: + log.debug("No default query expansion model configured, using original query") + return query + + model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}" + + # Use custom prompt from config if provided, otherwise use built-in default + # Users only need to configure the model - prompt is automatic with optional override + if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE: + # Custom prompt from config - format if it contains {query} placeholder + prompt = ( + query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query) + if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE + else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE + ) + else: + # Use built-in default prompt and format with query + prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query) + + request = OpenAIChatCompletionRequestWithExtraBody( + model=model_id, + messages=[{"role": "user", "content": prompt}], + max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS, + temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE, + ) + + response = await self.inference_api.openai_chat_completion(request) + rewritten_query = response.choices[0].message.content.strip() + log.debug(f"Query rewritten: '{query}' → '{rewritten_query}'") + return rewritten_query + async def query_chunks( self, query: InterleavedContent, @@ -304,10 +334,6 @@ class VectorStoreWithIndex: if params is None: params = {} - # Extract configuration if provided by router - if "vector_stores_config" in params: - self.vector_stores_config = params["vector_stores_config"] - k = params.get("max_chunks", 3) mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) @@ -331,18 +357,9 @@ class VectorStoreWithIndex: query_string = interleaved_content_as_str(query) - # Apply query rewriting if enabled + # Apply query rewriting if enabled and model is configured if params.get("rewrite_query", False): - if self.vector_stores_config: - log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}") - if hasattr(self.vector_stores_config, "default_query_expansion_model"): - log.debug( - f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}" - ) - else: - log.debug("No vector_stores_config found - cannot perform query rewriting") - - query_string = await self._rewrite_query_for_search(query_string) + query_string = await self._rewrite_query_for_file_search(query_string) if mode == "keyword": return await self.index.query_keyword(query_string, k, score_threshold) @@ -359,88 +376,3 @@ class VectorStoreWithIndex: ) else: return await self.index.query_vector(query_vector, k, score_threshold) - - async def _rewrite_query_for_search(self, query: str) -> str: - """Rewrite the user query to improve vector search performance. 
- - :param query: The original user query - :returns: The rewritten query optimized for vector search - """ - expansion_model = None - - # Check for per-store query expansion model first - if self.vector_store.query_expansion_model: - # Parse the model string into provider_id and model_id - model_parts = self.vector_store.query_expansion_model.split("/", 1) - if len(model_parts) == 2: - expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1]) - log.debug(f"Using per-store query expansion model: {expansion_model}") - else: - log.warning( - f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'" - ) - - # Fall back to global default if no per-store model - if not expansion_model: - if not self.vector_stores_config: - raise ValueError( - f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}" - ) - if not self.vector_stores_config.default_query_expansion_model: - raise ValueError( - f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}" - ) - expansion_model = self.vector_stores_config.default_query_expansion_model - log.debug(f"Using global default query expansion model: {expansion_model}") - - chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}" - - # Validate that the model is available and is an LLM - try: - models_response = await self.inference_api.routing_table.list_models() - except Exception as e: - raise RuntimeError(f"Failed to list available models for validation: {e}") from e - - model_found = False - for model in models_response.data: - if model.identifier == chat_model: - if model.model_type != ModelType.llm: - raise ValueError( - f"Configured query expansion model '{chat_model}' is not an LLM model " - f"(found type: {model.model_type}). Query rewriting requires an LLM model." - ) - model_found = True - break - - if not model_found: - available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm] - raise ValueError( - f"Configured query expansion model '{chat_model}' is not available. 
" - f"Available LLM models: {available_llm_models}" - ) - - # Use the configured prompt (has a default value) - rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query) - - chat_request = OpenAIChatCompletionRequestWithExtraBody( - model=chat_model, - messages=[ - OpenAIUserMessageParam( - role="user", - content=rewrite_prompt, - ) - ], - max_tokens=100, - ) - - try: - response = await self.inference_api.openai_chat_completion(chat_request) - except Exception as e: - raise RuntimeError(f"Failed to generate rewritten query: {e}") from e - - if response.choices and len(response.choices) > 0: - rewritten_query = response.choices[0].message.content.strip() - log.info(f"Query rewritten: '{query}' → '{rewritten_query}'") - return rewritten_query - else: - raise RuntimeError("No response received from LLM model for query rewriting") diff --git a/src/llama_stack_api/vector_stores.py b/src/llama_stack_api/vector_stores.py index 4c0d1ced2..0a1e6c53c 100644 --- a/src/llama_stack_api/vector_stores.py +++ b/src/llama_stack_api/vector_stores.py @@ -25,7 +25,6 @@ class VectorStore(Resource): embedding_model: str embedding_dimension: int vector_store_name: str | None = None - query_expansion_model: str | None = None @property def vector_store_id(self) -> str: diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index cfda7aa5e..83bf22f34 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -1233,94 +1233,122 @@ async def test_embedding_config_required_model_missing(vector_io_adapter): async def test_query_expansion_functionality(vector_io_adapter): - """Test query expansion with per-store models, global defaults, and error validation.""" + """Test query expansion with simplified global configuration approach.""" from unittest.mock import MagicMock - from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig + from llama_stack.core.datatypes import QualifiedModel + from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT + from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex - from llama_stack_api.models import Model, ModelType + from llama_stack_api import QueryChunksResponse - vector_io_adapter.register_vector_store = AsyncMock() - vector_io_adapter.__provider_id__ = "test_provider" - - # Test 1: Per-store model usage - params = OpenAICreateVectorStoreRequestWithExtraBody( - name="test_store", - metadata={}, - **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"}, - ) - await vector_io_adapter.openai_create_vector_store(params) - call_args = vector_io_adapter.register_vector_store.call_args[0][0] - assert call_args.query_expansion_model == "test/llama-model" - - # Test 2: Global default fallback - vector_io_adapter.register_vector_store.reset_mock() - params_no_model = OpenAICreateVectorStoreRequestWithExtraBody( - name="test_store2", metadata={}, **{"embedding_model": "test/embedding"} - ) - await vector_io_adapter.openai_create_vector_store(params_no_model) - call_args2 = vector_io_adapter.register_vector_store.call_args[0][0] - assert call_args2.query_expansion_model is None - - # Test query rewriting scenarios - mock_inference_api = MagicMock() - - # 
Per-store model scenario + # Mock a simple vector store and index mock_vector_store = MagicMock() - mock_vector_store.query_expansion_model = "test/llama-model" - mock_inference_api.routing_table.list_models = AsyncMock( - return_value=MagicMock( - data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)] - ) - ) - mock_inference_api.openai_chat_completion = AsyncMock( - return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))]) - ) + mock_vector_store.embedding_model = "test/embedding" + mock_inference_api = MagicMock() + mock_index = MagicMock() + # Create VectorStoreWithIndex with simplified constructor vector_store_with_index = VectorStoreWithIndex( vector_store=mock_vector_store, - index=MagicMock(), + index=mock_index, inference_api=mock_inference_api, - vector_stores_config=VectorStoresConfig( - default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default") - ), ) - result = await vector_store_with_index._rewrite_query_for_search("test") - assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model" - assert result == "per-store expanded" + # Mock the query_vector method to return a simple response + mock_response = QueryChunksResponse(chunks=[], scores=[]) + mock_index.query_vector = AsyncMock(return_value=mock_response) - # Global default fallback scenario - mock_inference_api.reset_mock() - mock_vector_store.query_expansion_model = None - mock_inference_api.routing_table.list_models = AsyncMock( - return_value=MagicMock( - data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)] - ) + # Mock embeddings generation + mock_inference_api.openai_embeddings = AsyncMock( + return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])]) ) + + # Test 1: Query expansion with default prompt (no custom prompt configured) + mock_vector_stores_config = MagicMock() + mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama") + mock_vector_stores_config.query_expansion_prompt = None # Use built-in default prompt + mock_vector_stores_config.query_expansion_max_tokens = 100 # Default value + mock_vector_stores_config.query_expansion_temperature = 0.3 # Default value + + # Set global config + set_default_query_expansion_config(mock_vector_stores_config) + + # Mock chat completion for query rewriting mock_inference_api.openai_chat_completion = AsyncMock( - return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))]) + return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))]) ) - result = await vector_store_with_index._rewrite_query_for_search("test") - assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default" - assert result == "global expanded" + params = {"rewrite_query": True, "max_chunks": 5} + result = await vector_store_with_index.query_chunks("test query", params) - # Test 3: Error cases - # Model not found - mock_vector_store.query_expansion_model = "missing/model" - mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[])) + # Verify chat completion was called for query rewriting + mock_inference_api.openai_chat_completion.assert_called_once() + chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0] + assert chat_call_args.model == "test/llama" - with pytest.raises(ValueError, match="Configured query expansion model .* is not 
available"): - await vector_store_with_index._rewrite_query_for_search("test") + # Verify default prompt is used (contains our built-in prompt text) + prompt_text = chat_call_args.messages[0].content + expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query") + assert prompt_text == expected_prompt - # Non-LLM model - mock_vector_store.query_expansion_model = "test/embedding-model" - mock_inference_api.routing_table.list_models = AsyncMock( - return_value=MagicMock( - data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)] - ) - ) + # Verify default inference parameters are used + assert chat_call_args.max_tokens == 100 # Default value + assert chat_call_args.temperature == 0.3 # Default value - with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"): - await vector_store_with_index._rewrite_query_for_search("test") + # Verify the rest of the flow proceeded normally + mock_inference_api.openai_embeddings.assert_called_once() + mock_index.query_vector.assert_called_once() + assert result == mock_response + + # Test 1b: Query expansion with custom prompt override and inference parameters + mock_inference_api.reset_mock() + mock_index.reset_mock() + + mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}" + mock_vector_stores_config.query_expansion_max_tokens = 150 + mock_vector_stores_config.query_expansion_temperature = 0.7 + set_default_query_expansion_config(mock_vector_stores_config) + + result = await vector_store_with_index.query_chunks("test query", params) + + # Verify custom prompt and parameters are used + mock_inference_api.openai_chat_completion.assert_called_once() + chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0] + prompt_text = chat_call_args.messages[0].content + assert prompt_text == "Custom prompt for rewriting: test query" + assert "Expand this query with relevant synonyms" not in prompt_text # Default not used + + # Verify custom inference parameters + assert chat_call_args.max_tokens == 150 + assert chat_call_args.temperature == 0.7 + + # Test 2: No query expansion when no global model is configured + mock_inference_api.reset_mock() + mock_index.reset_mock() + + # Clear global config + set_default_query_expansion_config(None) + + params = {"rewrite_query": True, "max_chunks": 5} + result2 = await vector_store_with_index.query_chunks("test query", params) + + # Verify chat completion was NOT called + mock_inference_api.openai_chat_completion.assert_not_called() + # But normal flow should still work + mock_inference_api.openai_embeddings.assert_called_once() + mock_index.query_vector.assert_called_once() + assert result2 == mock_response + + # Test 3: Normal behavior without rewrite_query parameter + mock_inference_api.reset_mock() + mock_index.reset_mock() + + params_no_rewrite = {"max_chunks": 5} + result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite) + + # Neither chat completion nor query rewriting should be called + mock_inference_api.openai_chat_completion.assert_not_called() + mock_inference_api.openai_embeddings.assert_called_once() + mock_index.query_vector.assert_called_once() + assert result3 == mock_response