refactor to only configure the model at build time

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-11-21 11:27:25 -05:00
parent 2cc7943fd6
commit d887f1f8bb
31 changed files with 280 additions and 315 deletions

View file

@ -18,6 +18,7 @@ from llama_stack.core.storage.datatypes import (
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
from llama_stack_api import (
Api,
Benchmark,
@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
description="Default LLM model for query expansion/rewriting in vector search.",
)
query_expansion_prompt: str = Field(
default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
default=DEFAULT_QUERY_EXPANSION_PROMPT,
description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
)
query_expansion_max_tokens: int = Field(
default=100,
description="Maximum number of tokens for query expansion responses.",
)
query_expansion_temperature: float = Field(
default=0.3,
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
)
class SafetyConfig(BaseModel):
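
For reference, a minimal sketch of how these build-time fields can be set when constructing the config in Python; the provider and model ids below are placeholders, and only the model needs to be supplied since the prompt, max_tokens and temperature fall back to the defaults above.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

# Hypothetical values; everything except the model keeps its default.
vector_stores = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
)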

View file

@ -374,13 +374,6 @@ async def instantiate_provider(
method = "get_adapter_impl"
args = [config, deps]
# Add vector_stores_config for vector_io providers
if (
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
and provider_spec.api == Api.vector_io
):
args.append(run_config.vector_stores)
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -401,11 +394,6 @@ async def instantiate_provider(
args.append(policy)
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
args.append(run_config.telemetry.enabled)
if (
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
and provider_spec.api == Api.vector_io
):
args.append(run_config.vector_stores)
fn = getattr(module, method)
impl = await fn(*args)
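
The removed branch followed the same optional-argument convention the loader still uses for policy and telemetry_enabled: inspect the factory's signature and append a value only when a matching parameter exists. A condensed sketch of that pattern (the call_factory helper is hypothetical, not the actual loader code):

import inspect

async def call_factory(fn, config, deps, policy=None):
    # Append optional values only when the factory declares a matching parameter.
    args = [config, deps]
    if policy is not None and "policy" in inspect.signature(fn).parameters:
        args.append(policy)
    return await fn(*args)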

View file

@ -99,19 +99,6 @@ class VectorIORouter(VectorIO):
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
# Ensure params dict exists and add vector_stores_config for query rewriting
if params is None:
params = {}
logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
logger.debug(
f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
)
params["vector_stores_config"] = self.vector_stores_config
return await provider.query_chunks(vector_store_id, query, params)
# OpenAI Vector Stores API endpoints
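
With the config injection gone, a caller that wants rewriting simply opts in per request; the expansion model itself comes from the stack-level config set at startup. A hedged usage sketch (the search helper and example query are placeholders):

async def search(router, vector_store_id: str):
    # router is a VectorIORouter; rewrite_query triggers LLM-based expansion
    # before the query is embedded and matched.
    return await router.query_chunks(
        vector_store_id,
        "how do I rotate api keys",
        params={"rewrite_query": True, "max_chunks": 5},
    )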

View file

@ -144,35 +144,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
if vector_stores_config is None:
return
# Validate default embedding model
default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
if default_embedding_model is not None:
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
)
if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(
f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
)
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
# Validate default query expansion model
default_query_expansion_model = vector_stores_config.default_query_expansion_model
if default_query_expansion_model is not None:
provider_id = default_query_expansion_model.provider_id
model_id = default_query_expansion_model.model_id
query_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(
f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
)
models_impl = impls[Api.models]
response = await models_impl.list_models()
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
query_expansion_model = llm_models_list.get(query_model_id)
if query_expansion_model is None:
raise ValueError(
f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
)
logger.debug(f"Validated default query expansion model: {query_model_id}")
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@ -437,6 +464,12 @@ class Stack:
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
await validate_safety_config(self.run_config.safety, impls)
# Set global query expansion configuration from stack config
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
set_default_query_expansion_config(self.run_config.vector_stores)
self.impls = impls
def create_registry_refresh_task(self):

View file

@ -296,5 +296,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -287,5 +287,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -299,5 +299,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -290,5 +290,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -296,5 +296,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -287,5 +287,7 @@ vector_stores:
Improved query:'
query_expansion_max_tokens: 100
query_expansion_temperature: 0.3
safety:
default_shield_id: llama-guard

View file

@ -6,19 +6,16 @@
from typing import Any
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(
config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
):
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -14,7 +14,6 @@ import faiss # type: ignore[import-untyped]
import numpy as np
from numpy.typing import NDArray
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -190,12 +189,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
config: FaissVectorIOConfig,
inference_api: Inference,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.cache: dict[str, VectorStoreWithIndex] = {}
async def initialize(self) -> None:
@ -211,7 +208,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
vector_store,
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
self.inference_api,
self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
@ -250,7 +246,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
async def list_vector_stores(self) -> list[VectorStore]:
@ -284,7 +279,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store_id] = index
return index

View file

@ -6,18 +6,15 @@
from typing import Any
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(
config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -14,7 +14,6 @@ import numpy as np
import sqlite_vec # type: ignore[import-untyped]
from numpy.typing import NDArray
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -391,12 +390,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
config,
inference_api: Inference,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.cache: dict[str, VectorStoreWithIndex] = {}
self.vector_store_table = None
@ -411,9 +408,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, index, self.inference_api, self.vector_stores_config
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
@ -437,9 +432,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, index, self.inference_api, self.vector_stores_config
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
@ -464,7 +457,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
kvstore=self.kvstore,
),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store_id] = index
return index

View file

@ -4,17 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api, ProviderSpec
from .config import ChromaVectorIOConfig
async def get_adapter_impl(
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
):
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -11,7 +11,6 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@ -126,13 +125,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Inference,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.client = None
self.cache = {}
self.vector_store_table = None
@ -165,7 +162,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
)
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
vector_store, ChromaIndex(self.client, collection), self.inference_api
)
async def unregister_vector_store(self, vector_store_id: str) -> None:
@ -210,9 +207,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
collection = await maybe_await(self.client.get_collection(vector_store_id))
if not collection:
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
index = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
)
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_store_id] = index
return index

View file

@ -4,18 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api, ProviderSpec
from .config import MilvusVectorIOConfig
async def get_adapter_impl(
config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
):
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -11,7 +11,6 @@ from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@ -273,14 +272,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
inference_api: Inference,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.cache = {}
self.client = None
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
@ -301,7 +298,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
kvstore=self.kvstore,
),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
@ -329,7 +325,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store=vector_store,
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
@ -352,7 +347,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store=vector_store,
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store_id] = index
return index

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api, ProviderSpec
from .config import PGVectorVectorIOConfig
@ -13,10 +12,9 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(
config: PGVectorVectorIOConfig,
deps: dict[Api, ProviderSpec],
vector_stores_config: VectorStoresConfig | None = None,
):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -13,7 +13,6 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@ -335,12 +334,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
config: PGVectorVectorIOConfig,
inference_api: Inference,
files_api: Files | None = None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.conn = None
self.cache = {}
self.vector_store_table = None
@ -396,7 +393,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
vector_store,
index=pgvector_index,
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
@ -428,7 +424,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
vector_store,
index=pgvector_index,
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
@ -469,9 +464,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
vector_store = VectorStore.model_validate_json(vector_store_data)
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
await index.initialize()
self.cache[vector_store_id] = VectorStoreWithIndex(
vector_store, index, self.inference_api, self.vector_stores_config
)
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
return self.cache[vector_store_id]
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:

View file

@ -4,17 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api, ProviderSpec
from .config import QdrantVectorIOConfig
async def get_adapter_impl(
config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
):
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -13,7 +13,6 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@ -153,14 +152,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Inference,
files_api: Files | None = None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.vector_store_table = None
self._qdrant_lock = asyncio.Lock()
@ -179,7 +176,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store,
QdrantIndex(self.client, vector_store.identifier),
self.inference_api,
self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -199,7 +195,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store=vector_store,
index=QdrantIndex(self.client, vector_store.identifier),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store.identifier] = index
@ -231,7 +226,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store=vector_store,
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store_id] = index
return index

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack_api import Api, ProviderSpec
from .config import WeaviateVectorIOConfig
@ -13,10 +12,9 @@ from .config import WeaviateVectorIOConfig
async def get_adapter_impl(
config: WeaviateVectorIOConfig,
deps: dict[Api, ProviderSpec],
vector_stores_config: VectorStoresConfig | None = None,
):
from .weaviate import WeaviateVectorIOAdapter
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -12,7 +12,6 @@ from numpy.typing import NDArray
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, HybridFusion
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
@ -268,12 +267,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
config: WeaviateVectorIOConfig,
inference_api: Inference,
files_api: Files | None,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.vector_stores_config = vector_stores_config
self.client_cache = {}
self.cache = {}
self.vector_store_table = None
@ -319,7 +316,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
vector_store=vector_store,
index=idx,
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
# Load OpenAI vector stores metadata into cache
@ -348,7 +344,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
vector_store,
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api,
self.vector_stores_config,
)
async def unregister_vector_store(self, vector_store_id: str) -> None:
@ -383,7 +378,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
vector_store=vector_store,
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
vector_stores_config=self.vector_stores_config,
)
self.cache[vector_store_id] = index
return index

View file

@ -3,3 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .constants import DEFAULT_QUERY_EXPANSION_PROMPT
__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"]

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Default prompt template for query expansion in vector search
DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
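
The template is filled with str.format and carries a single {query} placeholder, so callers only supply the user query; a quick sketch (the example query is a placeholder):

from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT

# Produces the full expansion prompt for the given user query.
prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="postgres connection pooling")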

View file

@ -379,11 +379,6 @@ class OpenAIVectorStoreMixin(ABC):
f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
)
# Extract query expansion model from extra_body if provided
query_expansion_model = extra_body.get("query_expansion_model")
if query_expansion_model:
logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_store_id (allow override, else generate)
@ -407,7 +402,6 @@ class OpenAIVectorStoreMixin(ABC):
provider_id=provider_id,
provider_resource_id=vector_store_id,
vector_store_name=params.name,
query_expansion_model=query_expansion_model,
)
await self.register_vector_store(vector_store)
@ -621,9 +615,6 @@ class OpenAIVectorStoreMixin(ABC):
"rewrite_query": rewrite_query,
}
# Add vector_stores_config if available (for query rewriting)
if hasattr(self, "vector_stores_config"):
params["vector_stores_config"] = self.vector_stores_config
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
# Global configuration for query expansion - set during stack startup
_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None
_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100
_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3
_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None
def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None):
"""Set the global default query expansion configuration from stack config."""
global \
_DEFAULT_QUERY_EXPANSION_MODEL, \
_QUERY_EXPANSION_PROMPT_OVERRIDE, \
_DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \
_DEFAULT_QUERY_EXPANSION_TEMPERATURE
if vector_stores_config:
_DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model
# Only set override if user provided a custom prompt different from default
if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
_QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt
else:
_QUERY_EXPANSION_PROMPT_OVERRIDE = None
_DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens
_DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature
else:
_DEFAULT_QUERY_EXPANSION_MODEL = None
_QUERY_EXPANSION_PROMPT_OVERRIDE = None
_DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100
_DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3
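
A hedged usage sketch of the module above: the stack calls the setter once during startup, after which the memory utilities read the module-level values directly. The model id and max_tokens value below are placeholders.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory import query_expansion_config

query_expansion_config.set_default_query_expansion_config(
    VectorStoresConfig(
        default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
        query_expansion_max_tokens=150,
    )
)
# Later reads pick up the configured values.
assert query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS == 150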

View file

@ -17,7 +17,6 @@ import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -30,19 +29,18 @@ from llama_stack_api import (
Chunk,
ChunkMetadata,
InterleavedContent,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
QueryChunksResponse,
RAGDocument,
VectorStore,
)
from llama_stack_api.inference import (
OpenAIChatCompletionRequestWithExtraBody,
OpenAIUserMessageParam,
)
from llama_stack_api.models import ModelType
log = get_logger(name=__name__, category="providers::utils")
from llama_stack.providers.utils.memory import query_expansion_config
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
@ -268,7 +266,6 @@ class VectorStoreWithIndex:
vector_store: VectorStore
index: EmbeddingIndex
inference_api: Api.inference
vector_stores_config: VectorStoresConfig | None = None
async def insert_chunks(
self,
@ -296,6 +293,39 @@ class VectorStoreWithIndex:
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
async def _rewrite_query_for_file_search(self, query: str) -> str:
"""Rewrite a search query using the globally configured LLM model for better retrieval results."""
if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
log.debug("No default query expansion model configured, using original query")
return query
model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
# Use custom prompt from config if provided, otherwise use built-in default
# Users only need to configure the model - prompt is automatic with optional override
if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
# Custom prompt from config - format if it contains {query} placeholder
prompt = (
query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
)
else:
# Use built-in default prompt and format with query
prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query)
request = OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[{"role": "user", "content": prompt}],
max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
)
response = await self.inference_api.openai_chat_completion(request)
rewritten_query = response.choices[0].message.content.strip()
log.debug(f"Query rewritten: '{query}''{rewritten_query}'")
return rewritten_query
async def query_chunks(
self,
query: InterleavedContent,
@ -304,10 +334,6 @@ class VectorStoreWithIndex:
if params is None:
params = {}
# Extract configuration if provided by router
if "vector_stores_config" in params:
self.vector_stores_config = params["vector_stores_config"]
k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
@ -331,18 +357,9 @@ class VectorStoreWithIndex:
query_string = interleaved_content_as_str(query)
# Apply query rewriting if enabled
# Apply query rewriting if enabled and model is configured
if params.get("rewrite_query", False):
if self.vector_stores_config:
log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
if hasattr(self.vector_stores_config, "default_query_expansion_model"):
log.debug(
f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
)
else:
log.debug("No vector_stores_config found - cannot perform query rewriting")
query_string = await self._rewrite_query_for_search(query_string)
query_string = await self._rewrite_query_for_file_search(query_string)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
@ -359,88 +376,3 @@ class VectorStoreWithIndex:
)
else:
return await self.index.query_vector(query_vector, k, score_threshold)
async def _rewrite_query_for_search(self, query: str) -> str:
"""Rewrite the user query to improve vector search performance.
:param query: The original user query
:returns: The rewritten query optimized for vector search
"""
expansion_model = None
# Check for per-store query expansion model first
if self.vector_store.query_expansion_model:
# Parse the model string into provider_id and model_id
model_parts = self.vector_store.query_expansion_model.split("/", 1)
if len(model_parts) == 2:
expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
log.debug(f"Using per-store query expansion model: {expansion_model}")
else:
log.warning(
f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
)
# Fall back to global default if no per-store model
if not expansion_model:
if not self.vector_stores_config:
raise ValueError(
f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
)
if not self.vector_stores_config.default_query_expansion_model:
raise ValueError(
f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
)
expansion_model = self.vector_stores_config.default_query_expansion_model
log.debug(f"Using global default query expansion model: {expansion_model}")
chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
# Validate that the model is available and is an LLM
try:
models_response = await self.inference_api.routing_table.list_models()
except Exception as e:
raise RuntimeError(f"Failed to list available models for validation: {e}") from e
model_found = False
for model in models_response.data:
if model.identifier == chat_model:
if model.model_type != ModelType.llm:
raise ValueError(
f"Configured query expansion model '{chat_model}' is not an LLM model "
f"(found type: {model.model_type}). Query rewriting requires an LLM model."
)
model_found = True
break
if not model_found:
available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
raise ValueError(
f"Configured query expansion model '{chat_model}' is not available. "
f"Available LLM models: {available_llm_models}"
)
# Use the configured prompt (has a default value)
rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
chat_request = OpenAIChatCompletionRequestWithExtraBody(
model=chat_model,
messages=[
OpenAIUserMessageParam(
role="user",
content=rewrite_prompt,
)
],
max_tokens=100,
)
try:
response = await self.inference_api.openai_chat_completion(chat_request)
except Exception as e:
raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
if response.choices and len(response.choices) > 0:
rewritten_query = response.choices[0].message.content.strip()
log.info(f"Query rewritten: '{query}''{rewritten_query}'")
return rewritten_query
else:
raise RuntimeError("No response received from LLM model for query rewriting")

View file

@ -25,7 +25,6 @@ class VectorStore(Resource):
embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
query_expansion_model: str | None = None
@property
def vector_store_id(self) -> str:

View file

@ -1233,94 +1233,122 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
async def test_query_expansion_functionality(vector_io_adapter):
"""Test query expansion with per-store models, global defaults, and error validation."""
"""Test query expansion with simplified global configuration approach."""
from unittest.mock import MagicMock
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.core.datatypes import QualifiedModel
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
from llama_stack_api.models import Model, ModelType
from llama_stack_api import QueryChunksResponse
vector_io_adapter.register_vector_store = AsyncMock()
vector_io_adapter.__provider_id__ = "test_provider"
# Test 1: Per-store model usage
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={},
**{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
)
await vector_io_adapter.openai_create_vector_store(params)
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.query_expansion_model == "test/llama-model"
# Test 2: Global default fallback
vector_io_adapter.register_vector_store.reset_mock()
params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
)
await vector_io_adapter.openai_create_vector_store(params_no_model)
call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args2.query_expansion_model is None
# Test query rewriting scenarios
mock_inference_api = MagicMock()
# Per-store model scenario
# Mock a simple vector store and index
mock_vector_store = MagicMock()
mock_vector_store.query_expansion_model = "test/llama-model"
mock_inference_api.routing_table.list_models = AsyncMock(
return_value=MagicMock(
data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
)
)
mock_inference_api.openai_chat_completion = AsyncMock(
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
)
mock_vector_store.embedding_model = "test/embedding"
mock_inference_api = MagicMock()
mock_index = MagicMock()
# Create VectorStoreWithIndex with simplified constructor
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store,
index=MagicMock(),
index=mock_index,
inference_api=mock_inference_api,
vector_stores_config=VectorStoresConfig(
default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
),
)
result = await vector_store_with_index._rewrite_query_for_search("test")
assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
assert result == "per-store expanded"
# Mock the query_vector method to return a simple response
mock_response = QueryChunksResponse(chunks=[], scores=[])
mock_index.query_vector = AsyncMock(return_value=mock_response)
# Global default fallback scenario
mock_inference_api.reset_mock()
mock_vector_store.query_expansion_model = None
mock_inference_api.routing_table.list_models = AsyncMock(
return_value=MagicMock(
data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
)
# Mock embeddings generation
mock_inference_api.openai_embeddings = AsyncMock(
return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
)
# Test 1: Query expansion with default prompt (no custom prompt configured)
mock_vector_stores_config = MagicMock()
mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
mock_vector_stores_config.query_expansion_prompt = None # Use built-in default prompt
mock_vector_stores_config.query_expansion_max_tokens = 100 # Default value
mock_vector_stores_config.query_expansion_temperature = 0.3 # Default value
# Set global config
set_default_query_expansion_config(mock_vector_stores_config)
# Mock chat completion for query rewriting
mock_inference_api.openai_chat_completion = AsyncMock(
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))])
)
result = await vector_store_with_index._rewrite_query_for_search("test")
assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
assert result == "global expanded"
params = {"rewrite_query": True, "max_chunks": 5}
result = await vector_store_with_index.query_chunks("test query", params)
# Test 3: Error cases
# Model not found
mock_vector_store.query_expansion_model = "missing/model"
mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
# Verify chat completion was called for query rewriting
mock_inference_api.openai_chat_completion.assert_called_once()
chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
assert chat_call_args.model == "test/llama"
with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
await vector_store_with_index._rewrite_query_for_search("test")
# Verify default prompt is used (contains our built-in prompt text)
prompt_text = chat_call_args.messages[0].content
expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
assert prompt_text == expected_prompt
# Non-LLM model
mock_vector_store.query_expansion_model = "test/embedding-model"
mock_inference_api.routing_table.list_models = AsyncMock(
return_value=MagicMock(
data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
)
)
# Verify default inference parameters are used
assert chat_call_args.max_tokens == 100 # Default value
assert chat_call_args.temperature == 0.3 # Default value
with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
await vector_store_with_index._rewrite_query_for_search("test")
# Verify the rest of the flow proceeded normally
mock_inference_api.openai_embeddings.assert_called_once()
mock_index.query_vector.assert_called_once()
assert result == mock_response
# Test 1b: Query expansion with custom prompt override and inference parameters
mock_inference_api.reset_mock()
mock_index.reset_mock()
mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
mock_vector_stores_config.query_expansion_max_tokens = 150
mock_vector_stores_config.query_expansion_temperature = 0.7
set_default_query_expansion_config(mock_vector_stores_config)
result = await vector_store_with_index.query_chunks("test query", params)
# Verify custom prompt and parameters are used
mock_inference_api.openai_chat_completion.assert_called_once()
chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
prompt_text = chat_call_args.messages[0].content
assert prompt_text == "Custom prompt for rewriting: test query"
assert "Expand this query with relevant synonyms" not in prompt_text # Default not used
# Verify custom inference parameters
assert chat_call_args.max_tokens == 150
assert chat_call_args.temperature == 0.7
# Test 2: No query expansion when no global model is configured
mock_inference_api.reset_mock()
mock_index.reset_mock()
# Clear global config
set_default_query_expansion_config(None)
params = {"rewrite_query": True, "max_chunks": 5}
result2 = await vector_store_with_index.query_chunks("test query", params)
# Verify chat completion was NOT called
mock_inference_api.openai_chat_completion.assert_not_called()
# But normal flow should still work
mock_inference_api.openai_embeddings.assert_called_once()
mock_index.query_vector.assert_called_once()
assert result2 == mock_response
# Test 3: Normal behavior without rewrite_query parameter
mock_inference_api.reset_mock()
mock_index.reset_mock()
params_no_rewrite = {"max_chunks": 5}
result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite)
# Neither chat completion nor query rewriting should be called
mock_inference_api.openai_chat_completion.assert_not_called()
mock_inference_api.openai_embeddings.assert_called_once()
mock_index.query_vector.assert_called_once()
assert result3 == mock_response