Add vector_stores_config to vector_io providers so it can properly be used for query rewriting

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-11-19 10:23:17 -05:00
parent 61a4738a12
commit ac7cb1ba5a
18 changed files with 168 additions and 30 deletions


@@ -374,6 +374,13 @@ async def instantiate_provider(
         method = "get_adapter_impl"
         args = [config, deps]
+        # Add vector_stores_config for vector_io providers
+        if (
+            "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
+            and provider_spec.api == Api.vector_io
+        ):
+            args.append(run_config.vector_stores)
 
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -394,6 +401,11 @@ async def instantiate_provider(
         args.append(policy)
     if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
         args.append(run_config.telemetry.enabled)
+    if (
+        "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
+        and provider_spec.api == Api.vector_io
+    ):
+        args.append(run_config.vector_stores)
 
     fn = getattr(module, method)
     impl = await fn(*args)
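
Note: the resolver only forwards the new argument when the provider entrypoint declares it, so providers that never opted in keep working unchanged. A minimal, self-contained sketch of that signature-inspection pattern (the stub entrypoint and names below are illustrative, not the real resolver):

import asyncio
import inspect

async def get_adapter_impl(config, deps, vector_stores_config=None):
    # Hypothetical entrypoint; real providers return an adapter impl.
    return {"config": config, "deps": deps, "vector_stores_config": vector_stores_config}

async def instantiate(entrypoint, config, deps, vector_stores, is_vector_io):
    args = [config, deps]
    # Forward the extra config only if the entrypoint accepts the parameter
    # and the provider serves the vector_io API, mirroring the check above.
    if "vector_stores_config" in inspect.signature(entrypoint).parameters and is_vector_io:
        args.append(vector_stores)
    return await entrypoint(*args)

print(asyncio.run(instantiate(get_adapter_impl, {}, {}, {"default_query_expansion_model": "m"}, True)))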


@@ -103,6 +103,13 @@ class VectorIORouter(VectorIO):
+        # Ensure params dict exists and add vector_stores_config for query rewriting
         if params is None:
             params = {}
+        logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
+        if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
+            logger.debug(
+                f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+            )
+        params["vector_stores_config"] = self.vector_stores_config
         return await provider.query_chunks(vector_store_id, query, params)
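
Note: the diff imports VectorStoresConfig from llama_stack.core.datatypes but never shows its definition; the only field it references is default_query_expansion_model. A plausible sketch of the model, stated purely as an assumption for the examples below:

from pydantic import BaseModel

class VectorStoresConfig(BaseModel):
    # Assumption: only this field is evidenced by the diff; the real
    # datatype in llama_stack.core.datatypes may carry more settings.
    default_query_expansion_model: str | None = None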


@@ -6,16 +6,19 @@
 from typing import Any
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api
 
 from .config import FaissVectorIOConfig
 
 
-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
+):
     from .faiss import FaissVectorIOAdapter
 
     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl
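
Note: a usage sketch of the widened entrypoint, with stubbed dependencies; in the stack the resolver supplies these via the registry (see the resolver hunk above). The default-constructible config is an assumption:

async def build_faiss_provider(inference_impl, files_impl, run_config):
    # Omitting vector_stores_config preserves the old behavior (defaults to None).
    return await get_provider_impl(
        FaissVectorIOConfig(),  # assumption: default-constructible
        {Api.inference: inference_impl, Api.files: files_impl},
        vector_stores_config=run_config.vector_stores,  # may be None
    )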


@@ -14,6 +14,7 @@ import faiss  # type: ignore[import-untyped]
 import numpy as np
 from numpy.typing import NDArray
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -184,10 +185,17 @@ class FaissIndex(EmbeddingIndex):
 
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config: FaissVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}
 
     async def initialize(self) -> None:
@@ -203,6 +211,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 vector_store,
                 await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
                 self.inference_api,
+                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
@@ -241,6 +250,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
 
     async def list_vector_stores(self) -> list[VectorStore]:
@@ -274,6 +284,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index
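
Note: across this commit the new config is passed to VectorStoreWithIndex as the fourth argument, positionally in some call sites and as vector_stores_config= in others. That implies a constructor shape like the following sketch, inferred from the call sites only; the real class may not be a dataclass and the annotated types are assumptions:

from dataclasses import dataclass

@dataclass
class VectorStoreWithIndex:
    vector_store: "VectorStore"            # assumption
    index: "EmbeddingIndex"                # assumption
    inference_api: "Inference"             # assumption
    vector_stores_config: "VectorStoresConfig | None" = None  # the new optional field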


@@ -6,15 +6,18 @@
 from typing import Any
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api
 
 from .config import SQLiteVectorIOConfig
 
 
-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
+):
     from .sqlite_vec import SQLiteVecVectorIOAdapter
 
     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -14,6 +14,7 @@ import numpy as np
 import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -385,10 +386,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
     """
 
-    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}
         self.vector_store_table = None
@@ -403,7 +411,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             index = await SQLiteVecIndex.create(
                 vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
             )
-            self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+            self.cache[vector_store.identifier] = VectorStoreWithIndex(
+                vector_store, index, self.inference_api, self.vector_stores_config
+            )
 
         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
@@ -427,7 +437,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         index = await SQLiteVecIndex.create(
             vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
         )
-        self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+        self.cache[vector_store.identifier] = VectorStoreWithIndex(
+            vector_store, index, self.inference_api, self.vector_stores_config
+        )
 
     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:
@@ -452,6 +464,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 kvstore=self.kvstore,
             ),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index


@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec
 
 from .config import ChromaVectorIOConfig
 
 
-async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .chroma import ChromaVectorIOAdapter
 
-    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -11,6 +11,7 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -125,11 +126,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.client = None
         self.cache = {}
         self.vector_store_table = None
@@ -162,7 +165,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             )
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api
+            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
         )
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -207,7 +210,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(vector_store_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
-        index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
+        index = VectorStoreWithIndex(
+            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
+        )
         self.cache[vector_store_id] = index
         return index


@@ -4,15 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec
 
 from .config import MilvusVectorIOConfig
 
 
-async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .milvus import MilvusVectorIOAdapter
 
     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -11,6 +11,7 @@ from typing import Any
 from numpy.typing import NDArray
 from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -272,12 +273,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -298,6 +301,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                     kvstore=self.kvstore,
                 ),
                 inference_api=self.inference_api,
+                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
@@ -325,6 +329,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -347,6 +352,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index


@@ -4,14 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec
 
 from .config import PGVectorVectorIOConfig
 
 
-async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: PGVectorVectorIOConfig,
+    deps: dict[Api, ProviderSpec],
+    vector_stores_config: VectorStoresConfig | None = None,
+):
     from .pgvector import PGVectorVectorIOAdapter
 
-    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -13,6 +13,7 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -330,11 +331,16 @@ class PGVectorIndex(EmbeddingIndex):
 
 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     def __init__(
-        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None = None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_store_table = None
@@ -386,7 +392,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 kvstore=self.kvstore,
             )
             await pgvector_index.initialize()
-            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+            index = VectorStoreWithIndex(
+                vector_store,
+                index=pgvector_index,
+                inference_api=self.inference_api,
+                vector_stores_config=self.vector_stores_config,
+            )
             self.cache[vector_store.identifier] = index
 
     async def shutdown(self) -> None:
@@ -413,7 +424,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+        index = VectorStoreWithIndex(
+            vector_store,
+            index=pgvector_index,
+            inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
+        )
         self.cache[vector_store.identifier] = index
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -453,7 +469,9 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
             await index.initialize()
-            self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+            self.cache[vector_store_id] = VectorStoreWithIndex(
+                vector_store, index, self.inference_api, self.vector_stores_config
+            )
         return self.cache[vector_store_id]
 
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:


@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec
 
 from .config import QdrantVectorIOConfig
 
 
-async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .qdrant import QdrantVectorIOAdapter
 
-    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -13,6 +13,7 @@ from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@@ -152,12 +153,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self._qdrant_lock = asyncio.Lock()
@@ -173,7 +176,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         for vector_store_data in stored_vector_stores:
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = VectorStoreWithIndex(
-                vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
+                vector_store,
+                QdrantIndex(self.client, vector_store.identifier),
+                self.inference_api,
+                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()
@@ -193,6 +199,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=QdrantIndex(self.client, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -224,6 +231,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
        )
         self.cache[vector_store_id] = index
         return index


@@ -4,14 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec
 
 from .config import WeaviateVectorIOConfig
 
 
-async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: WeaviateVectorIOConfig,
+    deps: dict[Api, ProviderSpec],
+    vector_stores_config: VectorStoresConfig | None = None,
+):
     from .weaviate import WeaviateVectorIOAdapter
 
-    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl


@@ -12,6 +12,7 @@ from numpy.typing import NDArray
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter, HybridFusion
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
@@ -262,10 +263,17 @@ class WeaviateIndex(EmbeddingIndex):
 
 class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_store_table = None
@@ -308,7 +316,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
         client = self._get_client()
         idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store=vector_store, index=idx, inference_api=self.inference_api
+            vector_store=vector_store,
+            index=idx,
+            inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
 
         # Load OpenAI vector stores metadata into cache
@@ -334,7 +345,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
+            vector_store,
+            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
+            self.inference_api,
+            self.vector_stores_config,
         )
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -369,6 +383,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index


@@ -613,6 +613,9 @@ class OpenAIVectorStoreMixin(ABC):
             "mode": search_mode,
             "rewrite_query": rewrite_query,
         }
+        # Add vector_stores_config if available (for query rewriting)
+        if hasattr(self, "vector_stores_config"):
+            params["vector_stores_config"] = self.vector_stores_config
         # TODO: Add support for ranking_options.ranker
         response = await self.query_chunks(

@@ -333,6 +333,15 @@ class VectorStoreWithIndex:
         # Apply query rewriting if enabled
         if params.get("rewrite_query", False):
+            if self.vector_stores_config:
+                log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
+                if hasattr(self.vector_stores_config, "default_query_expansion_model"):
+                    log.debug(
+                        f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+                    )
+            else:
+                log.debug("No vector_stores_config found - cannot perform query rewriting")
             query_string = await self._rewrite_query_for_search(query_string)
 
         if mode == "keyword":
@@ -358,8 +367,14 @@ class VectorStoreWithIndex:
         :returns: The rewritten query optimized for vector search
         """
         # Check if query expansion model is configured
-        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
-            raise ValueError("No default_query_expansion_model configured for query rewriting")
+        if not self.vector_stores_config:
+            raise ValueError(
+                f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
+            )
+        if not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError(
+                f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+            )
 
         # Use the configured model
         expansion_model = self.vector_stores_config.default_query_expansion_model
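
Note: splitting the old combined check distinguishes "no config was threaded through at all" from "config present but default_query_expansion_model left unset", so operators can tell which of the two to fix. A compact restatement of that logic as a standalone helper (hypothetical factoring; in the commit the checks live inline in _rewrite_query_for_search):

def require_expansion_model(vector_stores_config):
    # First error: the config object never reached this layer.
    if not vector_stores_config:
        raise ValueError("No vector_stores_config found for query rewriting")
    # Second error: the config arrived but the model field is unset.
    if not vector_stores_config.default_query_expansion_model:
        raise ValueError("No default_query_expansion_model configured for query rewriting")
    return vector_stores_config.default_query_expansion_model

Using the assumed VectorStoresConfig sketch from earlier, require_expansion_model(VectorStoresConfig(default_query_expansion_model="llama3.2:3b")) returns the model id, while a default-constructed config raises the second error.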