refactor: configure the query expansion model only at build time

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-11-21 11:27:25 -05:00
parent 2cc7943fd6
commit d887f1f8bb
31 changed files with 280 additions and 315 deletions

View file

@@ -18,6 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,
@@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
         description="Default LLM model for query expansion/rewriting in vector search.",
     )
     query_expansion_prompt: str = Field(
-        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        default=DEFAULT_QUERY_EXPANSION_PROMPT,
         description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
     )
+    query_expansion_max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens for query expansion responses.",
+    )
+    query_expansion_temperature: float = Field(
+        default=0.3,
+        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+    )


 class SafetyConfig(BaseModel):

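For reference, a minimal sketch of how the expanded config could be constructed in code, assuming VectorStoresConfig and QualifiedModel are importable from llama_stack.core.datatypes as elsewhere in this diff and that both model fields are optional; the provider and model names are placeholders:

# Sketch only: the field names come from the diff above; the concrete model values are invented.
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

vector_stores = VectorStoresConfig(
    default_embedding_model=QualifiedModel(provider_id="ollama", model_id="all-minilm:l6-v2"),
    default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
    # query_expansion_prompt is omitted, so DEFAULT_QUERY_EXPANSION_PROMPT is used.
    query_expansion_max_tokens=100,
    query_expansion_temperature=0.3,
)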
View file

@@ -374,13 +374,6 @@ async def instantiate_provider(
         method = "get_adapter_impl"
         args = [config, deps]

-        # Add vector_stores_config for vector_io providers
-        if (
-            "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
-            and provider_spec.api == Api.vector_io
-        ):
-            args.append(run_config.vector_stores)
-
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -401,11 +394,6 @@ async def instantiate_provider(
             args.append(policy)
         if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
             args.append(run_config.telemetry.enabled)
-        if (
-            "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
-            and provider_spec.api == Api.vector_io
-        ):
-            args.append(run_config.vector_stores)

     fn = getattr(module, method)
     impl = await fn(*args)

View file

@@ -99,19 +99,6 @@ class VectorIORouter(VectorIO):
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
-        # Ensure params dict exists and add vector_stores_config for query rewriting
-        if params is None:
-            params = {}
-        logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
-        if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
-            logger.debug(
-                f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-            )
-        params["vector_stores_config"] = self.vector_stores_config
-
         return await provider.query_chunks(vector_store_id, query, params)

     # OpenAI Vector Stores API endpoints

View file

@@ -144,35 +144,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
     if vector_stores_config is None:
         return

-    default_embedding_model = vector_stores_config.default_embedding_model
-    if default_embedding_model is None:
-        return
-
-    provider_id = default_embedding_model.provider_id
-    model_id = default_embedding_model.model_id
-    default_model_id = f"{provider_id}/{model_id}"
-
-    if Api.models not in impls:
-        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
-
-    models_impl = impls[Api.models]
-    response = await models_impl.list_models()
-    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
-
-    default_model = models_list.get(default_model_id)
-    if default_model is None:
-        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
-
-    embedding_dimension = default_model.metadata.get("embedding_dimension")
-    if embedding_dimension is None:
-        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
-
-    try:
-        int(embedding_dimension)
-    except ValueError as err:
-        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
-
-    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+    # Validate default embedding model
+    default_embedding_model = vector_stores_config.default_embedding_model
+    if default_embedding_model is not None:
+        provider_id = default_embedding_model.provider_id
+        model_id = default_embedding_model.model_id
+        default_model_id = f"{provider_id}/{model_id}"
+
+        if Api.models not in impls:
+            raise ValueError(
+                f"Models API is not available but vector_stores config requires model '{default_model_id}'"
+            )
+
+        models_impl = impls[Api.models]
+        response = await models_impl.list_models()
+        models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
+
+        default_model = models_list.get(default_model_id)
+        if default_model is None:
+            raise ValueError(
+                f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
+            )
+
+        embedding_dimension = default_model.metadata.get("embedding_dimension")
+        if embedding_dimension is None:
+            raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
+
+        try:
+            int(embedding_dimension)
+        except ValueError as err:
+            raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
+
+        logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+
+    # Validate default query expansion model
+    default_query_expansion_model = vector_stores_config.default_query_expansion_model
+    if default_query_expansion_model is not None:
+        provider_id = default_query_expansion_model.provider_id
+        model_id = default_query_expansion_model.model_id
+        query_model_id = f"{provider_id}/{model_id}"
+
+        if Api.models not in impls:
+            raise ValueError(
+                f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
+            )
+
+        models_impl = impls[Api.models]
+        response = await models_impl.list_models()
+        llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+
+        query_expansion_model = llm_models_list.get(query_model_id)
+        if query_expansion_model is None:
+            raise ValueError(
+                f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
+            )
+
+        logger.debug(f"Validated default query expansion model: {query_model_id}")


 async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -437,6 +464,12 @@ class Stack:
         await refresh_registry_once(impls)
         await validate_vector_stores_config(self.run_config.vector_stores, impls)
         await validate_safety_config(self.run_config.safety, impls)
+
+        # Set global query expansion configuration from stack config
+        from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+
+        set_default_query_expansion_config(self.run_config.vector_stores)
+
         self.impls = impls

     def create_registry_refresh_task(self):

View file

@@ -296,5 +296,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -287,5 +287,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -299,5 +299,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -290,5 +290,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -296,5 +296,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -287,5 +287,7 @@ vector_stores:
     Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

View file

@@ -6,19 +6,16 @@
 from typing import Any

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(
-    config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -14,7 +14,6 @@ import faiss  # type: ignore[import-untyped]
 import numpy as np
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -190,12 +189,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         config: FaissVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}

     async def initialize(self) -> None:
@@ -211,7 +208,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
                 vector_store,
                 await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
                 self.inference_api,
-                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
@@ -250,7 +246,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )

     async def list_vector_stores(self) -> list[VectorStore]:
@@ -284,7 +279,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -6,18 +6,15 @@
 from typing import Any

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(
-    config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -14,7 +14,6 @@ import numpy as np
 import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -391,12 +390,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         config,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}
         self.vector_store_table = None
@@ -411,9 +408,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
             index = await SQLiteVecIndex.create(
                 vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
             )
-            self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store, index, self.inference_api, self.vector_stores_config
-            )
+            self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
@@ -437,9 +432,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         index = await SQLiteVecIndex.create(
             vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
         )
-        self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, index, self.inference_api, self.vector_stores_config
-        )
+        self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:
@@ -464,7 +457,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
                 kvstore=self.kvstore,
             ),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,17 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import ChromaVectorIOConfig


-async def get_adapter_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter

-    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -11,7 +11,6 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -126,13 +125,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.client = None
         self.cache = {}
         self.vector_store_table = None
@@ -165,7 +162,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             )
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
+            vector_store, ChromaIndex(self.client, collection), self.inference_api
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -210,9 +207,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         collection = await maybe_await(self.client.get_collection(vector_store_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
-        index = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
-        )
+        index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,18 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import MilvusVectorIOConfig


-async def get_adapter_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .milvus import MilvusVectorIOAdapter

     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -11,7 +11,6 @@ from typing import Any
 from numpy.typing import NDArray
 from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -273,14 +272,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -301,7 +298,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
                     kvstore=self.kvstore,
                 ),
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
@@ -329,7 +325,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -352,7 +347,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import PGVectorVectorIOConfig
@@ -13,10 +12,9 @@ from .config import PGVectorVectorIOConfig
 async def get_adapter_impl(
     config: PGVectorVectorIOConfig,
     deps: dict[Api, ProviderSpec],
-    vector_stores_config: VectorStoresConfig | None = None,
 ):
     from .pgvector import PGVectorVectorIOAdapter

-    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -13,7 +13,6 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -335,12 +334,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         config: PGVectorVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_store_table = None
@@ -396,7 +393,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
                 vector_store,
                 index=pgvector_index,
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
@@ -428,7 +424,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             vector_store,
             index=pgvector_index,
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -469,9 +464,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
             await index.initialize()
-            self.cache[vector_store_id] = VectorStoreWithIndex(
-                vector_store, index, self.inference_api, self.vector_stores_config
-            )
+            self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)

         return self.cache[vector_store_id]

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:

View file

@@ -4,17 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorIOAdapter

-    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -13,7 +13,6 @@ from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@@ -153,14 +152,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self._qdrant_lock = asyncio.Lock()
@@ -179,7 +176,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
                 vector_store,
                 QdrantIndex(self.client, vector_store.identifier),
                 self.inference_api,
-                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()
@@ -199,7 +195,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=QdrantIndex(self.client, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -231,7 +226,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import WeaviateVectorIOConfig
@@ -13,10 +12,9 @@ from .config import WeaviateVectorIOConfig
 async def get_adapter_impl(
     config: WeaviateVectorIOConfig,
     deps: dict[Api, ProviderSpec],
-    vector_stores_config: VectorStoresConfig | None = None,
 ):
     from .weaviate import WeaviateVectorIOAdapter

-    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -12,7 +12,6 @@ from numpy.typing import NDArray
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter, HybridFusion

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
@@ -268,12 +267,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         config: WeaviateVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_store_table = None
@@ -319,7 +316,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
                 vector_store=vector_store,
                 index=idx,
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )

         # Load OpenAI vector stores metadata into cache
@@ -348,7 +344,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             vector_store,
             WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             self.inference_api,
-            self.vector_stores_config,
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -383,7 +378,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             vector_store=vector_store,
             index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"]

View file

@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Default prompt template for query expansion in vector search
+DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"

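Since the template's only placeholder is {query}, rendering it is a plain str.format call; a quick illustration with a made-up query:

from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT

# Produces the full instruction text with the user's query substituted for {query}.
prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="how do I rotate API keys?")
print(prompt)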
View file

@@ -379,11 +379,6 @@ class OpenAIVectorStoreMixin(ABC):
                 f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
             )

-        # Extract query expansion model from extra_body if provided
-        query_expansion_model = extra_body.get("query_expansion_model")
-        if query_expansion_model:
-            logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
-
         # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
         provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
         # Derive the canonical vector_store_id (allow override, else generate)
@@ -407,7 +402,6 @@ class OpenAIVectorStoreMixin(ABC):
             provider_id=provider_id,
             provider_resource_id=vector_store_id,
             vector_store_name=params.name,
-            query_expansion_model=query_expansion_model,
         )

         await self.register_vector_store(vector_store)
@@ -621,9 +615,6 @@ class OpenAIVectorStoreMixin(ABC):
             "rewrite_query": rewrite_query,
         }

-        # Add vector_stores_config if available (for query rewriting)
-        if hasattr(self, "vector_stores_config"):
-            params["vector_stores_config"] = self.vector_stores_config
-
         # TODO: Add support for ranking_options.ranker
         response = await self.query_chunks(
View file

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+# Global configuration for query expansion - set during stack startup
+_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None
+_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100
+_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3
+_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None
+
+
+def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None):
+    """Set the global default query expansion configuration from stack config."""
+    global \
+        _DEFAULT_QUERY_EXPANSION_MODEL, \
+        _QUERY_EXPANSION_PROMPT_OVERRIDE, \
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE
+
+    if vector_stores_config:
+        _DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model
+        # Only set override if user provided a custom prompt different from default
+        if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+            _QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt
+        else:
+            _QUERY_EXPANSION_PROMPT_OVERRIDE = None
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature
+    else:
+        _DEFAULT_QUERY_EXPANSION_MODEL = None
+        _QUERY_EXPANSION_PROMPT_OVERRIDE = None
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3

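A small usage sketch of the new helper, mirroring the Stack startup hunk earlier in this commit; the model values are illustrative and it assumes a VectorStoresConfig can be built with just these fields:

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config

# Called once during stack startup; query-time code then reads the module-level defaults.
set_default_query_expansion_config(
    VectorStoresConfig(
        default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
        query_expansion_max_tokens=100,
        query_expansion_temperature=0.3,
    )
)

# Passing None resets the globals to their built-in defaults (no model, 100 tokens, 0.3 temperature).
set_default_query_expansion_config(None)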
View file

@@ -17,7 +17,6 @@ import numpy as np
 from numpy.typing import NDArray
 from pydantic import BaseModel

-from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -30,19 +29,18 @@ from llama_stack_api import (
     Chunk,
     ChunkMetadata,
     InterleavedContent,
+    OpenAIChatCompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     QueryChunksResponse,
     RAGDocument,
     VectorStore,
 )
-from llama_stack_api.inference import (
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAIUserMessageParam,
-)
-from llama_stack_api.models import ModelType

 log = get_logger(name=__name__, category="providers::utils")

+from llama_stack.providers.utils.memory import query_expansion_config
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+

 class ChunkForDeletion(BaseModel):
     """Information needed to delete a chunk from a vector store.
@@ -268,7 +266,6 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
-    vector_stores_config: VectorStoresConfig | None = None

     async def insert_chunks(
         self,
@@ -296,6 +293,39 @@ class VectorStoreWithIndex:
         embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
         await self.index.add_chunks(chunks, embeddings)

+    async def _rewrite_query_for_file_search(self, query: str) -> str:
+        """Rewrite a search query using the globally configured LLM model for better retrieval results."""
+        if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
+            log.debug("No default query expansion model configured, using original query")
+            return query
+
+        model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
+
+        # Use custom prompt from config if provided, otherwise use built-in default
+        # Users only need to configure the model - prompt is automatic with optional override
+        if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
+            # Custom prompt from config - format if it contains {query} placeholder
+            prompt = (
+                query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
+                if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
+                else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
+            )
+        else:
+            # Use built-in default prompt and format with query
+            prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query)
+
+        request = OpenAIChatCompletionRequestWithExtraBody(
+            model=model_id,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
+            temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
+        )
+        response = await self.inference_api.openai_chat_completion(request)
+        rewritten_query = response.choices[0].message.content.strip()
+        log.debug(f"Query rewritten: '{query}' -> '{rewritten_query}'")
+        return rewritten_query
+
     async def query_chunks(
         self,
         query: InterleavedContent,
@@ -304,10 +334,6 @@ class VectorStoreWithIndex:
         if params is None:
             params = {}
-
-        # Extract configuration if provided by router
-        if "vector_stores_config" in params:
-            self.vector_stores_config = params["vector_stores_config"]
-
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -331,18 +357,9 @@ class VectorStoreWithIndex:
         query_string = interleaved_content_as_str(query)

-        # Apply query rewriting if enabled
+        # Apply query rewriting if enabled and model is configured
         if params.get("rewrite_query", False):
-            if self.vector_stores_config:
-                log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
-                if hasattr(self.vector_stores_config, "default_query_expansion_model"):
-                    log.debug(
-                        f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-                    )
-            else:
-                log.debug("No vector_stores_config found - cannot perform query rewriting")
-            query_string = await self._rewrite_query_for_search(query_string)
+            query_string = await self._rewrite_query_for_file_search(query_string)

         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
@@ -359,88 +376,3 @@ class VectorStoreWithIndex:
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
-
-    async def _rewrite_query_for_search(self, query: str) -> str:
-        """Rewrite the user query to improve vector search performance.
-
-        :param query: The original user query
-        :returns: The rewritten query optimized for vector search
-        """
-        expansion_model = None
-
-        # Check for per-store query expansion model first
-        if self.vector_store.query_expansion_model:
-            # Parse the model string into provider_id and model_id
-            model_parts = self.vector_store.query_expansion_model.split("/", 1)
-            if len(model_parts) == 2:
-                expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
-                log.debug(f"Using per-store query expansion model: {expansion_model}")
-            else:
-                log.warning(
-                    f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
-                )
-
-        # Fall back to global default if no per-store model
-        if not expansion_model:
-            if not self.vector_stores_config:
-                raise ValueError(
-                    f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
-                )
-            if not self.vector_stores_config.default_query_expansion_model:
-                raise ValueError(
-                    f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-                )
-            expansion_model = self.vector_stores_config.default_query_expansion_model
-            log.debug(f"Using global default query expansion model: {expansion_model}")
-
-        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
-
-        # Validate that the model is available and is an LLM
-        try:
-            models_response = await self.inference_api.routing_table.list_models()
-        except Exception as e:
-            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
-
-        model_found = False
-        for model in models_response.data:
-            if model.identifier == chat_model:
-                if model.model_type != ModelType.llm:
-                    raise ValueError(
-                        f"Configured query expansion model '{chat_model}' is not an LLM model "
-                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
-                    )
-                model_found = True
-                break
-
-        if not model_found:
-            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
-            raise ValueError(
-                f"Configured query expansion model '{chat_model}' is not available. "
-                f"Available LLM models: {available_llm_models}"
-            )
-
-        # Use the configured prompt (has a default value)
-        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
-
-        chat_request = OpenAIChatCompletionRequestWithExtraBody(
-            model=chat_model,
-            messages=[
-                OpenAIUserMessageParam(
-                    role="user",
-                    content=rewrite_prompt,
-                )
-            ],
-            max_tokens=100,
-        )
-
-        try:
-            response = await self.inference_api.openai_chat_completion(chat_request)
-        except Exception as e:
-            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
-
-        if response.choices and len(response.choices) > 0:
-            rewritten_query = response.choices[0].message.content.strip()
-            log.info(f"Query rewritten: '{query}' -> '{rewritten_query}'")
-            return rewritten_query
-        else:
-            raise RuntimeError("No response received from LLM model for query rewriting")

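From the caller's side, rewriting is now opt-in purely through the rewrite_query flag in params, which matches the updated test further below; a hedged usage sketch, where store_with_index stands for any VectorStoreWithIndex built by the adapters above and the code runs inside an async function:

# Assumes a global query expansion model was configured at startup; otherwise the original query is used as-is.
response = await store_with_index.query_chunks(
    "how do I rotate API keys?",
    {"rewrite_query": True, "max_chunks": 5},
)
for chunk, score in zip(response.chunks, response.scores):
    print(score, chunk.content)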
View file

@@ -25,7 +25,6 @@ class VectorStore(Resource):
     embedding_model: str
     embedding_dimension: int
     vector_store_name: str | None = None
-    query_expansion_model: str | None = None

     @property
     def vector_store_id(self) -> str:

View file

@@ -1233,94 +1233,122 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
 async def test_query_expansion_functionality(vector_io_adapter):
-    """Test query expansion with per-store models, global defaults, and error validation."""
+    """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock

-    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
-    from llama_stack_api.models import Model, ModelType
+    from llama_stack_api import QueryChunksResponse

-    vector_io_adapter.register_vector_store = AsyncMock()
-    vector_io_adapter.__provider_id__ = "test_provider"
-
-    # Test 1: Per-store model usage
-    params = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store",
-        metadata={},
-        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
-    )
-    await vector_io_adapter.openai_create_vector_store(params)
-    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args.query_expansion_model == "test/llama-model"
-
-    # Test 2: Global default fallback
-    vector_io_adapter.register_vector_store.reset_mock()
-    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
-    )
-    await vector_io_adapter.openai_create_vector_store(params_no_model)
-    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args2.query_expansion_model is None
-
-    # Test query rewriting scenarios
-    mock_inference_api = MagicMock()
-
-    # Per-store model scenario
+    # Mock a simple vector store and index
     mock_vector_store = MagicMock()
-    mock_vector_store.query_expansion_model = "test/llama-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
-        )
-    )
-    mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
-    )
+    mock_vector_store.embedding_model = "test/embedding"
+    mock_inference_api = MagicMock()
+    mock_index = MagicMock()

+    # Create VectorStoreWithIndex with simplified constructor
     vector_store_with_index = VectorStoreWithIndex(
         vector_store=mock_vector_store,
-        index=MagicMock(),
+        index=mock_index,
         inference_api=mock_inference_api,
-        vector_stores_config=VectorStoresConfig(
-            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
-        ),
     )

-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
-    assert result == "per-store expanded"
+    # Mock the query_vector method to return a simple response
+    mock_response = QueryChunksResponse(chunks=[], scores=[])
+    mock_index.query_vector = AsyncMock(return_value=mock_response)

-    # Global default fallback scenario
-    mock_inference_api.reset_mock()
-    mock_vector_store.query_expansion_model = None
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
-        )
-    )
+    # Mock embeddings generation
+    mock_inference_api.openai_embeddings = AsyncMock(
+        return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
+    )
+
+    # Test 1: Query expansion with default prompt (no custom prompt configured)
+    mock_vector_stores_config = MagicMock()
+    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
+    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
+    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
+    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+
+    # Set global config
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))])
     )

-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
-    assert result == "global expanded"
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result = await vector_store_with_index.query_chunks("test query", params)

-    # Test 3: Error cases
-    # Model not found
-    mock_vector_store.query_expansion_model = "missing/model"
-    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
+    # Verify chat completion was called for query rewriting
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    assert chat_call_args.model == "test/llama"

-    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify default prompt is used (contains our built-in prompt text)
+    prompt_text = chat_call_args.messages[0].content
+    expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
+    assert prompt_text == expected_prompt

-    # Non-LLM model
-    mock_vector_store.query_expansion_model = "test/embedding-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
-        )
-    )
+    # Verify default inference parameters are used
+    assert chat_call_args.max_tokens == 100  # Default value
+    assert chat_call_args.temperature == 0.3  # Default value

-    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify the rest of the flow proceeded normally
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result == mock_response
+
+    # Test 1b: Query expansion with custom prompt override and inference parameters
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
+    mock_vector_stores_config.query_expansion_max_tokens = 150
+    mock_vector_stores_config.query_expansion_temperature = 0.7
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    result = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify custom prompt and parameters are used
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    prompt_text = chat_call_args.messages[0].content
+    assert prompt_text == "Custom prompt for rewriting: test query"
+    assert "Expand this query with relevant synonyms" not in prompt_text  # Default not used
+
+    # Verify custom inference parameters
+    assert chat_call_args.max_tokens == 150
+    assert chat_call_args.temperature == 0.7
+
+    # Test 2: No query expansion when no global model is configured
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    # Clear global config
+    set_default_query_expansion_config(None)
+
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result2 = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify chat completion was NOT called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+
+    # But normal flow should still work
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result2 == mock_response
+
+    # Test 3: Normal behavior without rewrite_query parameter
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    params_no_rewrite = {"max_chunks": 5}
+    result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite)
+
+    # Neither chat completion nor query rewriting should be called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result3 == mock_response