Add vector_stores_config to providers so that it can properly be used

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-11-19 10:23:17 -05:00
parent 61a4738a12
commit ac7cb1ba5a
18 changed files with 168 additions and 30 deletions

View file

@@ -374,6 +374,13 @@ async def instantiate_provider(
         method = "get_adapter_impl"
         args = [config, deps]
+        # Add vector_stores_config for vector_io providers
+        if (
+            "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
+            and provider_spec.api == Api.vector_io
+        ):
+            args.append(run_config.vector_stores)

     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -394,6 +401,11 @@ async def instantiate_provider(
         args.append(policy)
     if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
         args.append(run_config.telemetry.enabled)
+    if (
+        "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
+        and provider_spec.api == Api.vector_io
+    ):
+        args.append(run_config.vector_stores)

     fn = getattr(module, method)
     impl = await fn(*args)
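
The dispatch above is purely signature-based: an entry point opts in to the new argument by declaring a vector_stores_config parameter, so providers that have not been updated keep working unchanged. A minimal self-contained sketch of the pattern (the toy entry points and the build_args helper are illustrative, not code from this repository):

import inspect


def get_adapter_impl(config, deps, vector_stores_config=None):
    # Toy entry point: declaring the parameter is the opt-in.
    return config, deps, vector_stores_config


def legacy_get_adapter_impl(config, deps):
    # Older entry point without the parameter; it receives only two args.
    return config, deps


def build_args(fn, config, deps, vector_stores):
    # Mirrors the resolver change: append the extra argument only when
    # the target signature declares vector_stores_config.
    args = [config, deps]
    if "vector_stores_config" in inspect.signature(fn).parameters:
        args.append(vector_stores)
    return args


assert len(build_args(get_adapter_impl, {}, {}, "vs")) == 3
assert len(build_args(legacy_get_adapter_impl, {}, {}, "vs")) == 2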

View file

@@ -103,6 +103,13 @@ class VectorIORouter(VectorIO):
         # Ensure params dict exists and add vector_stores_config for query rewriting
         if params is None:
             params = {}
+        logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
+        if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
+            logger.debug(
+                f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+            )
         params["vector_stores_config"] = self.vector_stores_config
         return await provider.query_chunks(vector_store_id, query, params)
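
On the receiving side, a provider's query_chunks can read the injected key back out of params. A hedged sketch of a consumer (the body is illustrative; only the params keys come from this diff):

async def query_chunks(vector_store_id: str, query: str, params: dict | None = None) -> dict:
    params = params or {}
    # The router always sets this key, but its value is None when the run
    # config has no vector_stores section.
    vs_config = params.get("vector_stores_config")
    rewrite = params.get("rewrite_query", False)
    return {"store": vector_store_id, "can_rewrite": bool(rewrite and vs_config)}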

View file

@@ -6,16 +6,19 @@
 from typing import Any

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
+):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl
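
The only VectorStoresConfig field this change exercises is default_query_expansion_model. A hypothetical stand-in for the real model in llama_stack.core.datatypes, which may carry additional fields; the model id is purely illustrative:

from pydantic import BaseModel


class VectorStoresConfig(BaseModel):
    # Model used to rewrite/expand queries before vector search.
    default_query_expansion_model: str | None = None


vs_config = VectorStoresConfig(default_query_expansion_model="ollama/llama3.2:3b")
# get_provider_impl(config, deps, vector_stores_config=vs_config)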

View file

@@ -14,6 +14,7 @@ import faiss  # type: ignore[import-untyped]
 import numpy as np
 from numpy.typing import NDArray

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -184,10 +185,17 @@ class FaissIndex(EmbeddingIndex):

 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config: FaissVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}

     async def initialize(self) -> None:
@@ -203,6 +211,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 vector_store,
                 await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
                 self.inference_api,
+                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
@@ -241,6 +250,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )

     async def list_vector_stores(self) -> list[VectorStore]:
@@ -274,6 +284,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -6,15 +6,18 @@
 from typing import Any

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
+):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -14,6 +14,7 @@ import numpy as np
 import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -385,10 +386,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}
         self.vector_store_table = None
@@ -403,7 +411,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             index = await SQLiteVecIndex.create(
                 vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
             )
-            self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+            self.cache[vector_store.identifier] = VectorStoreWithIndex(
+                vector_store, index, self.inference_api, self.vector_stores_config
+            )

         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
@@ -427,7 +437,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         index = await SQLiteVecIndex.create(
             vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
         )
-        self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+        self.cache[vector_store.identifier] = VectorStoreWithIndex(
+            vector_store, index, self.inference_api, self.vector_stores_config
+        )

     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:
@@ -452,6 +464,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 kvstore=self.kvstore,
             ),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import ChromaVectorIOConfig


-async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .chroma import ChromaVectorIOAdapter

-    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -11,6 +11,7 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -125,11 +126,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.client = None
         self.cache = {}
         self.vector_store_table = None
@@ -162,7 +165,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             )
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api
+            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -207,7 +210,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(vector_store_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
-        index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
+        index = VectorStoreWithIndex(
+            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
+        )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,15 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import MilvusVectorIOConfig


-async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .milvus import MilvusVectorIOAdapter

     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -11,6 +11,7 @@ from typing import Any
 from numpy.typing import NDArray
 from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -272,12 +273,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self.metadata_collection_name = "openai_vector_stores_metadata"
@@ -298,6 +301,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                     kvstore=self.kvstore,
                 ),
                 inference_api=self.inference_api,
+                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
@@ -325,6 +329,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )

         self.cache[vector_store.identifier] = index
@@ -347,6 +352,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,14 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import PGVectorVectorIOConfig


-async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: PGVectorVectorIOConfig,
+    deps: dict[Api, ProviderSpec],
+    vector_stores_config: VectorStoresConfig | None = None,
+):
     from .pgvector import PGVectorVectorIOAdapter

-    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -13,6 +13,7 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -330,11 +331,16 @@ class PGVectorIndex(EmbeddingIndex):

 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     def __init__(
-        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None = None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_store_table = None
@@ -386,7 +392,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
                 kvstore=self.kvstore,
             )
             await pgvector_index.initialize()
-            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+            index = VectorStoreWithIndex(
+                vector_store,
+                index=pgvector_index,
+                inference_api=self.inference_api,
+                vector_stores_config=self.vector_stores_config,
+            )
             self.cache[vector_store.identifier] = index

     async def shutdown(self) -> None:
@@ -413,7 +424,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+        index = VectorStoreWithIndex(
+            vector_store,
+            index=pgvector_index,
+            inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
+        )
         self.cache[vector_store.identifier] = index

     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -453,7 +469,9 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         vector_store = VectorStore.model_validate_json(vector_store_data)
         index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
         await index.initialize()
-        self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
+        self.cache[vector_store_id] = VectorStoreWithIndex(
+            vector_store, index, self.inference_api, self.vector_stores_config
+        )
         return self.cache[vector_store_id]

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:

View file

@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
+):
     from .qdrant import QdrantVectorIOAdapter

-    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -13,6 +13,7 @@ from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@@ -152,12 +153,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self._qdrant_lock = asyncio.Lock()
@@ -173,7 +176,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
         for vector_store_data in stored_vector_stores:
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = VectorStoreWithIndex(
-                vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
+                vector_store,
+                QdrantIndex(self.client, vector_store.identifier),
+                self.inference_api,
+                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()
@@ -193,6 +199,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=QdrantIndex(self.client, vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index
@@ -224,6 +231,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -4,14 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import WeaviateVectorIOConfig


-async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
+async def get_adapter_impl(
+    config: WeaviateVectorIOConfig,
+    deps: dict[Api, ProviderSpec],
+    vector_stores_config: VectorStoresConfig | None = None,
+):
     from .weaviate import WeaviateVectorIOAdapter

-    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
     await impl.initialize()
     return impl

View file

@@ -12,6 +12,7 @@ from numpy.typing import NDArray
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter, HybridFusion

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
@@ -262,10 +263,17 @@ class WeaviateIndex(EmbeddingIndex):

 class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
+    ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
+        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_store_table = None
@@ -308,7 +316,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
             client = self._get_client()
             idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
             self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store=vector_store, index=idx, inference_api=self.inference_api
+                vector_store=vector_store,
+                index=idx,
+                inference_api=self.inference_api,
+                vector_stores_config=self.vector_stores_config,
             )

         # Load OpenAI vector stores metadata into cache
@@ -334,7 +345,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
+            vector_store,
+            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
+            self.inference_api,
+            self.vector_stores_config,
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -369,6 +383,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
             vector_store=vector_store,
             index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
+            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index

View file

@@ -613,6 +613,9 @@ class OpenAIVectorStoreMixin(ABC):
             "mode": search_mode,
             "rewrite_query": rewrite_query,
         }
+        # Add vector_stores_config if available (for query rewriting)
+        if hasattr(self, "vector_stores_config"):
+            params["vector_stores_config"] = self.vector_stores_config
         # TODO: Add support for ranking_options.ranker
         response = await self.query_chunks(
View file

@@ -333,6 +333,15 @@ class VectorStoreWithIndex:
         # Apply query rewriting if enabled
         if params.get("rewrite_query", False):
+            if self.vector_stores_config:
+                log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
+                if hasattr(self.vector_stores_config, "default_query_expansion_model"):
+                    log.debug(
+                        f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+                    )
+            else:
+                log.debug("No vector_stores_config found - cannot perform query rewriting")
             query_string = await self._rewrite_query_for_search(query_string)

         if mode == "keyword":
@@ -358,8 +367,14 @@ class VectorStoreWithIndex:
         :returns: The rewritten query optimized for vector search
         """
         # Check if query expansion model is configured
-        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
-            raise ValueError("No default_query_expansion_model configured for query rewriting")
+        if not self.vector_stores_config:
+            raise ValueError(
+                f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
+            )
+        if not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError(
+                f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+            )

         # Use the configured model
         expansion_model = self.vector_stores_config.default_query_expansion_model
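
Splitting the old combined check in two makes the failure modes distinguishable: a config that was never threaded through versus a config that is present but names no model. A sketch that mirrors the validation order (the helper name is illustrative):

def resolve_expansion_model(vector_stores_config) -> str:
    # Mirrors _rewrite_query_for_search: first confirm a config was
    # threaded through at all, then that it names an expansion model.
    if not vector_stores_config:
        raise ValueError("No vector_stores_config found for query rewriting")
    if not vector_stores_config.default_query_expansion_model:
        raise ValueError("No default_query_expansion_model configured")
    return vector_stores_config.default_query_expansion_model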