diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py
index 8fab715f2..00527a1bd 100644
--- a/src/llama_stack/core/datatypes.py
+++ b/src/llama_stack/core/datatypes.py
@@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,
@@ -374,7 +374,7 @@ class RewriteQueryParams(BaseModel):
         description="LLM model for query rewriting/expansion in vector search.",
     )
     prompt: str = Field(
-        default=DEFAULT_QUERY_EXPANSION_PROMPT,
+        default=DEFAULT_QUERY_REWRITE_PROMPT,
         description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
     )
     max_tokens: int = Field(
diff --git a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
index e2aab1a25..91a17058b 100644
--- a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -184,12 +184,7 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
-    def __init__(
-        self,
-        config: FaissVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index bc6226c84..a384a33dc 100644
--- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -385,12 +385,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
     and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
     """
 
-    def __init__(
-        self,
-        config,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py
index ea0139815..36018fd95 100644
--- a/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py
+++ b/src/llama_stack/providers/remote/vector_io/pgvector/__init__.py
@@ -9,10 +9,7 @@ from llama_stack_api import Api, ProviderSpec
 from .config import PGVectorVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: PGVectorVectorIOConfig,
-    deps: dict[Api, ProviderSpec],
-):
+async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter
 
     impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index fe1b8ce35..5c86fb08d 100644
--- a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -330,10 +330,7 @@ class PGVectorIndex(EmbeddingIndex):
 
 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     def __init__(
-        self,
-        config: PGVectorVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None = None,
+        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
@@ -389,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
                 kvstore=self.kvstore,
             )
             await pgvector_index.initialize()
-            index = VectorStoreWithIndex(
-                vector_store,
-                index=pgvector_index,
-                inference_api=self.inference_api,
-            )
+            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
             self.cache[vector_store.identifier] = index
 
     async def shutdown(self) -> None:
@@ -420,11 +413,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorStoreWithIndex(
-            vector_store,
-            index=pgvector_index,
-            inference_api=self.inference_api,
-        )
+        index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
         self.cache[vector_store.identifier] = index
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:
diff --git a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index dc6546646..4dd78d834 100644
--- a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -173,9 +173,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         for vector_store_data in stored_vector_stores:
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = VectorStoreWithIndex(
-                vector_store,
-                QdrantIndex(self.client, vector_store.identifier),
-                self.inference_api,
+                vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()
diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py
index a13cca8a1..47546d459 100644
--- a/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py
+++ b/src/llama_stack/providers/remote/vector_io/weaviate/__init__.py
@@ -9,10 +9,7 @@ from llama_stack_api import Api, ProviderSpec
 from .config import WeaviateVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: WeaviateVectorIOConfig,
-    deps: dict[Api, ProviderSpec],
-):
+async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter
 
     impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 67ec523d7..c15d5f468 100644
--- a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -262,12 +262,7 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
-    def __init__(
-        self,
-        config: WeaviateVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
@@ -313,9 +308,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             client = self._get_client()
             idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
             self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store=vector_store,
-                index=idx,
-                inference_api=self.inference_api,
+                vector_store=vector_store, index=idx, inference_api=self.inference_api
             )
 
         # Load OpenAI vector stores metadata into cache
@@ -341,9 +334,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         )
 
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store,
-            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
-            self.inference_api,
+            vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
         )
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:
diff --git a/src/llama_stack/providers/utils/memory/__init__.py b/src/llama_stack/providers/utils/memory/__init__.py
index 5e0942402..05a832b6f 100644
--- a/src/llama_stack/providers/utils/memory/__init__.py
+++ b/src/llama_stack/providers/utils/memory/__init__.py
@@ -4,6 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from .constants import DEFAULT_QUERY_REWRITE_PROMPT
 
-__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"]
+__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]
diff --git a/src/llama_stack/providers/utils/memory/constants.py b/src/llama_stack/providers/utils/memory/constants.py
index d8703bbce..1f6e2cef6 100644
--- a/src/llama_stack/providers/utils/memory/constants.py
+++ b/src/llama_stack/providers/utils/memory/constants.py
@@ -4,5 +4,5 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Default prompt template for query expansion in vector search
-DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
+# Default prompt template for query rewriting in vector search
+DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
diff --git a/src/llama_stack/providers/utils/memory/rewrite_query_config.py b/src/llama_stack/providers/utils/memory/rewrite_query_config.py
index 9c53638b8..7128116dd 100644
--- a/src/llama_stack/providers/utils/memory/rewrite_query_config.py
+++ b/src/llama_stack/providers/utils/memory/rewrite_query_config.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 
 # Global configuration for query rewriting - set during stack startup
 _DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
@@ -25,7 +25,7 @@ def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig |
         params = vector_stores_config.rewrite_query_params
         _DEFAULT_REWRITE_QUERY_MODEL = params.model
         # Only set override if user provided a custom prompt different from default
-        if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+        if params.prompt != DEFAULT_QUERY_REWRITE_PROMPT:
             _REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
         else:
             _REWRITE_QUERY_PROMPT_OVERRIDE = None
diff --git a/src/llama_stack/providers/utils/memory/vector_store.py b/src/llama_stack/providers/utils/memory/vector_store.py
index 11754bae2..e22075a5f 100644
--- a/src/llama_stack/providers/utils/memory/vector_store.py
+++ b/src/llama_stack/providers/utils/memory/vector_store.py
@@ -39,7 +39,7 @@ from llama_stack_api import (
 log = get_logger(name=__name__, category="providers::utils")
 
 from llama_stack.providers.utils.memory import rewrite_query_config
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 
 
 class ChunkForDeletion(BaseModel):
@@ -312,7 +312,7 @@ class VectorStoreWithIndex:
             )
         else:
             # Use built-in default prompt and format with query
-            prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query)
+            prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query=query)
 
         request = OpenAIChatCompletionRequestWithExtraBody(
             model=model_id,
diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
index 07ec41bec..4588fe7e5 100644
--- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
+++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
@@ -1237,7 +1237,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
     from unittest.mock import MagicMock
 
     from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
-    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
     from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
     from llama_stack_api import QueryChunksResponse
@@ -1288,7 +1288,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
 
     # Verify default prompt is used (contains our built-in prompt text)
     prompt_text = chat_call_args.messages[0].content
-    expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
+    expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
     assert prompt_text == expected_prompt
 
     # Verify default inference parameters are used