diff --git a/llama_stack/providers/inline/memory/faiss/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py
index 16c383be3..2d7ede3b1 100644
--- a/llama_stack/providers/inline/memory/faiss/__init__.py
+++ b/llama_stack/providers/inline/memory/faiss/__init__.py
@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import FaissImplConfig
 
 
-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl
 
     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"
 
-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py
index 78de13120..cb090c870 100644
--- a/llama_stack/providers/inline/memory/faiss/faiss.py
+++ b/llama_stack/providers/inline/memory/faiss/faiss.py
@@ -19,13 +19,12 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
-    BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import FaissImplConfig
 
@@ -95,6 +94,15 @@ class FaissIndex(EmbeddingIndex):
         await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Guard against embeddings whose dimension does not match the index
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
@@ -123,9 +131,10 @@ class FaissIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None
 
@@ -138,10 +147,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
 
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+            index = self._create_bank_with_index(
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
             )
             self.cache[bank.identifier] = index
@@ -166,13 +175,12 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )
 
         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]
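A note on the new guard in `add_chunks`: FAISS exposes the dimension an index was built with as `index.d` and fails with a bare assertion on mismatched input, so the explicit check turns a hard-to-debug failure into a readable `ValueError`. A minimal standalone sketch of the failure mode it catches (assumes `faiss-cpu` and `numpy`; the sizes are arbitrary):

```python
import faiss
import numpy as np

index = faiss.IndexFlatL2(384)  # an index built for 384-dim embeddings

good = np.random.rand(2, 384).astype(np.float32)
index.add(good)  # shapes agree with index.d, insert succeeds

bad = np.random.rand(2, 768).astype(np.float32)
assert bad.shape[1] != index.d
# index.add(bad) would trip an opaque assertion inside FAISS;
# the new add_chunks guard raises a clear ValueError first.
```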
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index ff0926108..8bc3d2e7b 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,
@@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,
@@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
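Declaring `api_dependencies=[Api.inference]` is what makes the distribution resolver hand an inference implementation to each provider's entry point. Roughly, and only as an illustration of the flow (the resolver's real code differs; `wire_provider` and `impls_by_api` are stand-in names):

```python
from llama_stack.providers.datatypes import Api

async def wire_provider(entry_point, config, api_dependencies, impls_by_api):
    # Collect the implementation for each declared dependency and pass
    # the dict to get_provider_impl / get_adapter_impl as `deps`.
    deps = {api: impls_by_api[api] for api in api_dependencies}
    return await entry_point(config, deps)

# e.g. await wire_provider(get_provider_impl, faiss_config, [Api.inference], impls)
```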
module="llama_stack.providers.remote.memory.sample", config_class="llama_stack.providers.remote.memory.sample.SampleConfig", ), + api_dependencies=[], ), remote_provider_spec( Api.memory, @@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.qdrant", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", ), + api_dependencies=[Api.inference], ), ] diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index dfd5c5696..936fabba1 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + from llama_stack.distribution.datatypes import RemoteProviderConfig +from llama_stack.providers.datatypes import Api, ProviderSpec -async def get_adapter_impl(config: RemoteProviderConfig, _deps): +async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]): from .chroma import ChromaMemoryAdapter - impl = ChromaMemoryAdapter(config.url) + impl = ChromaMemoryAdapter(config.url, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 207f6b54d..f2b48a3be 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, + InferenceEmbeddingMixin, ) log = logging.getLogger(__name__) @@ -71,8 +72,8 @@ class ChromaIndex(EmbeddingIndex): await self.client.delete_collection(self.collection.name) -class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, url: str) -> None: +class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate): + def __init__(self, url: str, inference_api: Api.inference) -> None: log.info(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") parsed = urlparse(url) @@ -82,6 +83,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.host = parsed.hostname self.port = parsed.port + self.inference_api = inference_api self.client = None self.cache = {} @@ -109,10 +111,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): name=memory_bank.identifier, metadata={"bank": memory_bank.model_dump_json()}, ) - bank_index = BankWithIndex( - bank=memory_bank, index=ChromaIndex(self.client, collection) + self.cache[memory_bank.identifier] = self._create_bank_with_index( + memory_bank, ChromaIndex(self.client, collection) ) - self.cache[memory_bank.identifier] = bank_index async def list_memory_banks(self) -> List[MemoryBank]: collections = await self.client.list_collections() @@ -124,11 +125,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): log.exception(f"Failed to parse bank: {collection.metadata}") continue - index = BankWithIndex( - bank=bank, - index=ChromaIndex(self.client, collection), + self.cache[bank.identifier] = self._create_bank_with_index( + bank, + ChromaIndex(self.client, collection), ) - self.cache[bank.identifier] = index return [i.bank for i in self.cache.values()] @@ -166,6 +166,6 @@ class 
diff --git a/llama_stack/providers/remote/memory/pgvector/__init__.py b/llama_stack/providers/remote/memory/pgvector/__init__.py
index 4ac30452f..b4620cae0 100644
--- a/llama_stack/providers/remote/memory/pgvector/__init__.py
+++ b/llama_stack/providers/remote/memory/pgvector/__init__.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig
 
 
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
 
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py
index d77de7b41..18d732534 100644
--- a/llama_stack/providers/remote/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py
@@ -16,11 +16,12 @@ from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
 
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import PGVectorConfig
@@ -119,9 +120,12 @@ class PGVectorIndex(EmbeddingIndex):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+class PGVectorMemoryAdapter(
+    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
+):
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +164,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank, index
         )
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
-
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
 
@@ -189,9 +183,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
             if bank.identifier not in self.cache:
-                index = BankWithIndex(
-                    bank=bank,
-                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+                index = self._create_bank_with_index(
+                    bank,
+                    PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
                 )
                 self.cache[bank.identifier] = index
         return banks
@@ -214,14 +208,11 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = self._create_bank_with_index(bank, index)
+        return self.cache[bank_id]
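Replacing `ALL_MINILM_L6_V2_DIMENSION` with `bank.embedding_dimension` matters most for pgvector, because the dimension is baked into the column type at table-creation time: a bank backed by a 768-dim model needs a differently typed table than a 384-dim one. Roughly the kind of DDL `PGVectorIndex` issues per bank (the table name here is illustrative):

```python
# Illustrative sketch: pgvector fixes vector(N) when the table is created,
# so the DDL must come from bank.embedding_dimension, not a global constant.
dimension = 768  # e.g. from a bank using a 768-dim embedding model

create_sql = f"""
CREATE TABLE IF NOT EXISTS vector_store_docs (
    id TEXT PRIMARY KEY,
    document JSONB,
    embedding vector({dimension})
)
"""
```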
diff --git a/llama_stack/providers/remote/memory/qdrant/__init__.py b/llama_stack/providers/remote/memory/qdrant/__init__.py
index 9f54babad..54605fcf9 100644
--- a/llama_stack/providers/remote/memory/qdrant/__init__.py
+++ b/llama_stack/providers/remote/memory/qdrant/__init__.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py
index be370eec9..f2c36438d 100644
--- a/llama_stack/providers/remote/memory/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py
@@ -21,6 +21,7 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 log = logging.getLogger(__name__)
@@ -100,11 +101,14 @@ class QdrantIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+class QdrantVectorMemoryAdapter(
+    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
+):
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass
@@ -120,7 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             memory_bank.memory_bank_type == MemoryBankType.vector
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
         )
@@ -140,7 +144,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         if not bank:
             raise ValueError(f"Bank {bank_id} not found")
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
         )
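`AsyncQdrantClient(**self.config.model_dump(exclude_none=True))` is worth a second look: the pydantic config doubles as the client's keyword arguments, and `exclude_none=True` drops every field the user left unset. Demonstrated with a stand-in config (the field names here are hypothetical, not the real `QdrantConfig`):

```python
from typing import Optional
from pydantic import BaseModel

class StubQdrantConfig(BaseModel):  # stand-in for QdrantConfig
    url: Optional[str] = None
    api_key: Optional[str] = None
    timeout: Optional[int] = None

cfg = StubQdrantConfig(url="http://localhost:6333")
kwargs = cfg.model_dump(exclude_none=True)
assert kwargs == {"url": "http://localhost:6333"}
# AsyncQdrantClient(**kwargs) then sees only explicitly configured options.
```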
diff --git a/llama_stack/providers/remote/memory/weaviate/__init__.py b/llama_stack/providers/remote/memory/weaviate/__init__.py
index 504bd1508..f7120bec0 100644
--- a/llama_stack/providers/remote/memory/weaviate/__init__.py
+++ b/llama_stack/providers/remote/memory/weaviate/__init__.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py
index f8fba5c0b..954bdcc68 100644
--- a/llama_stack/providers/remote/memory/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py
@@ -19,7 +19,8 @@
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
+    InferenceEmbeddingMixin,
 )
 
 from .config import WeaviateConfig, WeaviateRequestProviderData
@@ -82,10 +83,14 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    InferenceEmbeddingMixin,
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
 
@@ -135,11 +140,10 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is
@@ -160,7 +164,7 @@ class WeaviateMemoryAdapter(
         if not client.collections.exists(bank.identifier):
             raise ValueError(f"Collection with name `{bank.identifier}` not found")
 
-        index = BankWithIndex(
+        index = self._create_bank_with_index(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
         )
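Weaviate is the one adapter whose client is per caller: credentials arrive with each request via `NeedsRequestProviderData`, so the adapter keeps a `client_cache` keyed by those credentials instead of holding a single connection. The shape of that cache, reduced to a generic sketch (names and the `connect` factory are illustrative, not the adapter's actual code):

```python
from typing import Callable, Dict, Tuple

class PerCallerClientCache:
    """One client per (url, api_key) pair, created lazily and reused."""

    def __init__(self, connect: Callable[[str, str], object]):
        self._connect = connect
        self._clients: Dict[Tuple[str, str], object] = {}

    def get(self, url: str, api_key: str) -> object:
        key = (url, api_key)
        if key not in self._clients:
            self._clients[key] = self._connect(url, api_key)
        return self._clients[key]
```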
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index eb83aa671..8ff91a36e 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
 
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
 
@@ -208,6 +193,24 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
+
+
+class InferenceEmbeddingMixin:
+    inference_api: Api.inference
+
+    def __init__(self, inference_api: Api.inference):
+        self.inference_api = inference_api
+
+    def _create_bank_with_index(
+        self, bank: VectorMemoryBank, index: EmbeddingIndex
+    ) -> BankWithIndex:
+        return BankWithIndex(
+            bank=bank,
+            index=index,
+            inference_api=self.inference_api,
+        )
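The net effect of the refactor is that every adapter now funnels embedding computation through the same `inference_api.embeddings(model, texts)` call instead of loading sentence-transformers in process. A sketch of the new query path with a stub standing in for the inference provider (the stub names are mine; the response shape follows the diff):

```python
import numpy as np

class StubEmbeddingsResponse:
    def __init__(self, embeddings):
        self.embeddings = embeddings  # one vector per input string

class StubInferenceApi:
    """Stand-in for the resolved Api.inference implementation."""

    async def embeddings(self, model: str, contents: list):
        return StubEmbeddingsResponse([[0.0] * 384 for _ in contents])

async def embed_query(inference_api, embedding_model: str, query_str: str):
    # Mirrors BankWithIndex.query_documents after this change.
    response = await inference_api.embeddings(embedding_model, [query_str])
    return np.array(response.embeddings[0], dtype=np.float32)
```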