mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
adding config to providers so that it can properly be used
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
61a4738a12
commit
ac7cb1ba5a
18 changed files with 168 additions and 30 deletions
|
|
@ -374,6 +374,13 @@ async def instantiate_provider(
|
|||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
|
||||
# Add vector_stores_config for vector_io providers
|
||||
if (
|
||||
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
|
||||
and provider_spec.api == Api.vector_io
|
||||
):
|
||||
args.append(run_config.vector_stores)
|
||||
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
|
|
@ -394,6 +401,11 @@ async def instantiate_provider(
|
|||
args.append(policy)
|
||||
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
|
||||
args.append(run_config.telemetry.enabled)
|
||||
if (
|
||||
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
|
||||
and provider_spec.api == Api.vector_io
|
||||
):
|
||||
args.append(run_config.vector_stores)
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
|
|||
|
|
@ -103,6 +103,13 @@ class VectorIORouter(VectorIO):
|
|||
# Ensure params dict exists and add vector_stores_config for query rewriting
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
|
||||
if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
|
||||
logger.debug(
|
||||
f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
|
||||
)
|
||||
|
||||
params["vector_stores_config"] = self.vector_stores_config
|
||||
|
||||
return await provider.query_chunks(vector_store_id, query, params)
|
||||
|
|
|
|||
|
|
@ -6,16 +6,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
||||
async def get_provider_impl(
|
||||
config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
|
||||
):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import faiss # type: ignore[import-untyped]
|
|||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -184,10 +185,17 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: FaissVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -203,6 +211,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
vector_store,
|
||||
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
self.inference_api,
|
||||
self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
|
|
@ -241,6 +250,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
|
|
@ -274,6 +284,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -6,15 +6,18 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
||||
async def get_provider_impl(
|
||||
config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
|
||||
):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import numpy as np
|
|||
import sqlite_vec # type: ignore[import-untyped]
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -385,10 +386,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.vector_store_table = None
|
||||
|
||||
|
|
@ -403,7 +411,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, index, self.inference_api, self.vector_stores_config
|
||||
)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -427,7 +437,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, index, self.inference_api, self.vector_stores_config
|
||||
)
|
||||
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
|
|
@ -452,6 +464,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -4,14 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(
|
||||
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
|
||||
):
|
||||
from .chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from urllib.parse import urlparse
|
|||
import chromadb
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
|
|
@ -125,11 +126,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_store_table = None
|
||||
|
|
@ -162,7 +165,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
)
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, ChromaIndex(self.client, collection), self.inference_api
|
||||
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
|
||||
)
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
|
|
@ -207,7 +210,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
collection = await maybe_await(self.client.get_collection(vector_store_id))
|
||||
if not collection:
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
|
||||
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
|
|
|
|||
|
|
@ -4,15 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(
|
||||
config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
|
||||
):
|
||||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import Any
|
|||
from numpy.typing import NDArray
|
||||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
|
|
@ -272,12 +273,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
|
|
@ -298,6 +301,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
|
|
@ -325,6 +329,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
vector_store=vector_store,
|
||||
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
|
@ -347,6 +352,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
vector_store=vector_store,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -4,14 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(
|
||||
config: PGVectorVectorIOConfig,
|
||||
deps: dict[Api, ProviderSpec],
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
):
|
||||
from .pgvector import PGVectorVectorIOAdapter
|
||||
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from psycopg2 import sql
|
|||
from psycopg2.extras import Json, execute_values
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
|
@ -330,11 +331,16 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
|
||||
self,
|
||||
config: PGVectorVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None = None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.vector_store_table = None
|
||||
|
|
@ -386,7 +392,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|||
kvstore=self.kvstore,
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -413,7 +424,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|||
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
|
|
@ -453,7 +469,9 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
self.cache[vector_store_id] = VectorStoreWithIndex(
|
||||
vector_store, index, self.inference_api, self.vector_stores_config
|
||||
)
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
|
|
|||
|
|
@ -4,14 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(
|
||||
config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
|
||||
):
|
||||
from .qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from numpy.typing import NDArray
|
|||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
|
|
@ -152,12 +153,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None = None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.vector_store_table = None
|
||||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -173,7 +176,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
|
||||
vector_store,
|
||||
QdrantIndex(self.client, vector_store.identifier),
|
||||
self.inference_api,
|
||||
self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
|
@ -193,6 +199,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
vector_store=vector_store,
|
||||
index=QdrantIndex(self.client, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
|
@ -224,6 +231,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|||
vector_store=vector_store,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -4,14 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(
|
||||
config: WeaviateVectorIOConfig,
|
||||
deps: dict[Api, ProviderSpec],
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
):
|
||||
from .weaviate import WeaviateVectorIOAdapter
|
||||
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from numpy.typing import NDArray
|
|||
from weaviate.classes.init import Auth
|
||||
from weaviate.classes.query import Filter, HybridFusion
|
||||
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -262,10 +263,17 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: WeaviateVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.vector_store_table = None
|
||||
|
|
@ -308,7 +316,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
client = self._get_client()
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store, index=idx, inference_api=self.inference_api
|
||||
vector_store=vector_store,
|
||||
index=idx,
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
|
|
@ -334,7 +345,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
)
|
||||
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
vector_store,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
self.inference_api,
|
||||
self.vector_stores_config,
|
||||
)
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
|
|
@ -369,6 +383,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
vector_store=vector_store,
|
||||
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_stores_config=self.vector_stores_config,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -613,6 +613,9 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
"mode": search_mode,
|
||||
"rewrite_query": rewrite_query,
|
||||
}
|
||||
# Add vector_stores_config if available (for query rewriting)
|
||||
if hasattr(self, "vector_stores_config"):
|
||||
params["vector_stores_config"] = self.vector_stores_config
|
||||
# TODO: Add support for ranking_options.ranker
|
||||
|
||||
response = await self.query_chunks(
|
||||
|
|
|
|||
|
|
@ -333,6 +333,15 @@ class VectorStoreWithIndex:
|
|||
|
||||
# Apply query rewriting if enabled
|
||||
if params.get("rewrite_query", False):
|
||||
if self.vector_stores_config:
|
||||
log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
|
||||
if hasattr(self.vector_stores_config, "default_query_expansion_model"):
|
||||
log.debug(
|
||||
f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
|
||||
)
|
||||
else:
|
||||
log.debug("No vector_stores_config found - cannot perform query rewriting")
|
||||
|
||||
query_string = await self._rewrite_query_for_search(query_string)
|
||||
|
||||
if mode == "keyword":
|
||||
|
|
@ -358,8 +367,14 @@ class VectorStoreWithIndex:
|
|||
:returns: The rewritten query optimized for vector search
|
||||
"""
|
||||
# Check if query expansion model is configured
|
||||
if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
|
||||
raise ValueError("No default_query_expansion_model configured for query rewriting")
|
||||
if not self.vector_stores_config:
|
||||
raise ValueError(
|
||||
f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
|
||||
)
|
||||
if not self.vector_stores_config.default_query_expansion_model:
|
||||
raise ValueError(
|
||||
f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
|
||||
)
|
||||
|
||||
# Use the configured model
|
||||
expansion_model = self.vector_stores_config.default_query_expansion_model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue