From 4f8b73b9e11e73d586cbe9a638f11135aeca9f57 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 12 Dec 2024 11:16:54 -0800 Subject: [PATCH] Vector store inference api (#598) # What does this PR do? Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models ## Test Plan torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.2-3B-Instruct" --embedding-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=weaviate" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY= --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=chroma" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=--env CHROMA_HOST=localhost --env CHROMA_PORT=8000 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=pgvector" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY= pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=faiss" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY= --- .../providers/inline/memory/faiss/__init__.py | 7 +- .../providers/inline/memory/faiss/faiss.py | 41 ++++++---- llama_stack/providers/registry/memory.py | 7 ++ .../remote/memory/chroma/__init__.py | 7 +- .../providers/remote/memory/chroma/chroma.py | 23 +++--- .../remote/memory/pgvector/__init__.py | 8 +- .../remote/memory/pgvector/pgvector.py | 43 ++++------ .../remote/memory/qdrant/__init__.py | 8 +- .../providers/remote/memory/qdrant/qdrant.py | 5 +- .../remote/memory/weaviate/__init__.py | 8 +- .../remote/memory/weaviate/weaviate.py | 27 +++++-- .../providers/tests/memory/conftest.py | 80 +++++++++++++++++-- .../providers/tests/memory/fixtures.py | 30 +++++-- .../providers/tests/memory/test_memory.py | 26 +++--- .../providers/utils/memory/vector_store.py | 33 +++----- 15 files changed, 235 insertions(+), 118 deletions(-) diff --git a/llama_stack/providers/inline/memory/faiss/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py index 16c383be3..2d7ede3b1 100644 --- a/llama_stack/providers/inline/memory/faiss/__init__.py +++ b/llama_stack/providers/inline/memory/faiss/__init__.py @@ -4,16 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec from .config import FaissImplConfig -async def get_provider_impl(config: FaissImplConfig, _deps): +async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]): from .faiss import FaissMemoryImpl assert isinstance( config, FaissImplConfig ), f"Unexpected config type: {type(config)}" - impl = FaissMemoryImpl(config) + impl = FaissMemoryImpl(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 78de13120..7c27aca85 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -19,11 +19,10 @@ from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -32,7 +31,8 @@ from .config import FaissImplConfig logger = logging.getLogger(__name__) -MEMORY_BANKS_PREFIX = "memory_banks:v1::" +MEMORY_BANKS_PREFIX = "memory_banks:v2::" +FAISS_INDEX_PREFIX = "faiss_index:v2::" class FaissIndex(EmbeddingIndex): @@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex): if not self.kvstore: return - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" stored_data = await self.kvstore.get(index_key) if stored_data: @@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex): "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" await self.kvstore.set(key=index_key, value=json.dumps(data)) async def delete(self): if not self.kvstore or not self.bank_id: return - await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") + await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}") async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + # Add dimension check + embedding_dim = ( + embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0] + ) + if embedding_dim != self.index.d: + raise ValueError( + f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}" + ) + indexlen = len(self.id_by_index) for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk @@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex): class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: FaissImplConfig) -> None: + def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cache = {} self.kvstore = None @@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): for bank_data in stored_banks: bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier + bank, + await FaissIndex.create( + bank.embedding_dimension, self.kvstore, bank.identifier ), + self.inference_api, ) self.cache[bank.identifier] = index @@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): ) # Store in cache - index = BankWithIndex( - bank=memory_bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + await FaissIndex.create( + memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier ), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index ff0926108..8bc3d2e7b 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", + api_dependencies=[Api.inference], ), InlineProviderSpec( api=Api.memory, @@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.chroma", config_class="llama_stack.distribution.datatypes.RemoteProviderConfig", ), + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.pgvector", config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", ), + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", ), + api_dependencies=[Api.inference], ), remote_provider_spec( api=Api.memory, @@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.sample", config_class="llama_stack.providers.remote.memory.sample.SampleConfig", ), + api_dependencies=[], ), remote_provider_spec( Api.memory, @@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.qdrant", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", ), + api_dependencies=[Api.inference], ), ] diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index dfd5c5696..936fabba1 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + from llama_stack.distribution.datatypes import RemoteProviderConfig +from llama_stack.providers.datatypes import Api, ProviderSpec -async def get_adapter_impl(config: RemoteProviderConfig, _deps): +async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]): from .chroma import ChromaMemoryAdapter - impl = ChromaMemoryAdapter(config.url) + impl = ChromaMemoryAdapter(config.url, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 207f6b54d..f073feda3 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -15,8 +15,7 @@ from numpy.typing import NDArray from pydantic import parse_obj_as from llama_stack.apis.memory import * # noqa: F403 - -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -72,7 +71,7 @@ class ChromaIndex(EmbeddingIndex): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, url: str) -> None: + def __init__(self, url: str, inference_api: Api.inference) -> None: log.info(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") parsed = urlparse(url) @@ -82,6 +81,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.host = parsed.hostname self.port = parsed.port + self.inference_api = inference_api self.client = None self.cache = {} @@ -109,10 +109,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): name=memory_bank.identifier, metadata={"bank": memory_bank.model_dump_json()}, ) - bank_index = BankWithIndex( - bank=memory_bank, index=ChromaIndex(self.client, collection) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, ChromaIndex(self.client, collection), self.inference_api ) - self.cache[memory_bank.identifier] = bank_index async def list_memory_banks(self) -> List[MemoryBank]: collections = await self.client.list_collections() @@ -124,11 +123,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): log.exception(f"Failed to parse bank: {collection.metadata}") continue - index = BankWithIndex( - bank=bank, - index=ChromaIndex(self.client, collection), + self.cache[bank.identifier] = BankWithIndex( + bank, + ChromaIndex(self.client, collection), + self.inference_api, ) - self.cache[bank.identifier] = index return [i.bank for i in self.cache.values()] @@ -166,6 +165,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): collection = await self.client.get_collection(bank_id) if not collection: raise ValueError(f"Bank {bank_id} not found in Chroma") - index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) + index = BankWithIndex( + bank, ChromaIndex(self.client, collection), self.inference_api + ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/pgvector/__init__.py b/llama_stack/providers/remote/memory/pgvector/__init__.py index 4ac30452f..b4620cae0 100644 --- a/llama_stack/providers/remote/memory/pgvector/__init__.py +++ b/llama_stack/providers/remote/memory/pgvector/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import PGVectorConfig -async def get_adapter_impl(config: PGVectorConfig, _deps): +async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]): from .pgvector import PGVectorMemoryAdapter - impl = PGVectorMemoryAdapter(config) + impl = PGVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index d77de7b41..ed1e61a67 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate + from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex): class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: PGVectorConfig) -> None: + def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cursor = None self.conn = None self.cache = {} @@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def shutdown(self) -> None: pass - async def register_memory_bank( - self, - memory_bank: MemoryBank, - ) -> None: + async def register_memory_bank(self, memory_bank: MemoryBank) -> None: assert ( memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" - upsert_models( - self.cursor, - [ - (memory_bank.identifier, memory_bank), - ], + upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)]) + index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, index, self.inference_api ) - index = BankWithIndex( - bank=memory_bank, - index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[memory_bank.identifier] = index - async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] @@ -190,8 +181,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): for bank in banks: if bank.identifier not in self.cache: index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + bank, + PGVectorIndex(bank, bank.embedding_dimension, self.cursor), + self.inference_api, ) self.cache[bank.identifier] = index return banks @@ -214,14 +206,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + self.inference_api = inference_api + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: if bank_id in self.cache: return self.cache[bank_id] bank = await self.memory_bank_store.get_memory_bank(bank_id) - index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[bank_id] = index - return index + index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor) + self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api) + return self.cache[bank_id] diff --git a/llama_stack/providers/remote/memory/qdrant/__init__.py b/llama_stack/providers/remote/memory/qdrant/__init__.py index 9f54babad..54605fcf9 100644 --- a/llama_stack/providers/remote/memory/qdrant/__init__.py +++ b/llama_stack/providers/remote/memory/qdrant/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import QdrantConfig -async def get_adapter_impl(config: QdrantConfig, _deps): +async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]): from .qdrant import QdrantVectorMemoryAdapter - impl = QdrantVectorMemoryAdapter(config) + impl = QdrantVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index be370eec9..f2f28e63a 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex): class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: QdrantConfig) -> None: + def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None: self.config = config self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.cache = {} + self.inference_api = inference_api async def initialize(self) -> None: pass @@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = BankWithIndex( bank=memory_bank, index=QdrantIndex(self.client, memory_bank.identifier), + inference_api=self.inference_api, ) self.cache[memory_bank.identifier] = index @@ -143,6 +145,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = BankWithIndex( bank=bank, index=QdrantIndex(client=self.client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/weaviate/__init__.py b/llama_stack/providers/remote/memory/weaviate/__init__.py index 504bd1508..f7120bec0 100644 --- a/llama_stack/providers/remote/memory/weaviate/__init__.py +++ b/llama_stack/providers/remote/memory/weaviate/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 -async def get_adapter_impl(config: WeaviateConfig, _deps): +async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]): from .weaviate import WeaviateMemoryAdapter - impl = WeaviateMemoryAdapter(config) + impl = WeaviateMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index f8fba5c0b..b409b697b 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -12,10 +12,11 @@ import weaviate import weaviate.classes as wvc from numpy.typing import NDArray from weaviate.classes.init import Auth +from weaviate.classes.query import Filter from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self, chunk_ids: List[str]) -> None: + collection = self.client.collections.get(self.collection_name) + collection.data.delete_many( + where=Filter.by_property("id").contains_any(chunk_ids) + ) + class WeaviateMemoryAdapter( - Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate + Memory, + NeedsRequestProviderData, + MemoryBanksProtocolPrivate, ): - def __init__(self, config: WeaviateConfig) -> None: + def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.client_cache = {} self.cache = {} @@ -117,7 +127,7 @@ class WeaviateMemoryAdapter( memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.memory_bank_type == MemoryBankType.vector + memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" client = self._get_client() @@ -135,11 +145,11 @@ class WeaviateMemoryAdapter( ], ) - index = BankWithIndex( - bank=memory_bank, - index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBank]: # TODO: right now the Llama Stack is the source of truth for these banks. That is @@ -163,6 +173,7 @@ class WeaviateMemoryAdapter( index = BankWithIndex( bank=bank, index=WeaviateIndex(client=client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 99ecbe794..023a1a156 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -6,9 +6,65 @@ import pytest +from ..conftest import get_provider_fixture_overrides + +from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import MEMORY_FIXTURES +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "inference": "meta_reference", + "memory": "faiss", + }, + id="meta_reference", + marks=pytest.mark.meta_reference, + ), + pytest.param( + { + "inference": "ollama", + "memory": "pgvector", + }, + id="ollama", + marks=pytest.mark.ollama, + ), + pytest.param( + { + "inference": "together", + "memory": "chroma", + }, + id="chroma", + marks=pytest.mark.chroma, + ), + pytest.param( + { + "inference": "bedrock", + "memory": "qdrant", + }, + id="qdrant", + marks=pytest.mark.qdrant, + ), + pytest.param( + { + "inference": "fireworks", + "memory": "weaviate", + }, + id="weaviate", + marks=pytest.mark.weaviate, + ), +] + + +def pytest_addoption(parser): + parser.addoption( + "--embedding-model", + action="store", + default=None, + help="Specify the embedding model to use for testing", + ) + + def pytest_configure(config): for fixture_name in MEMORY_FIXTURES: config.addinivalue_line( @@ -18,12 +74,22 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): + if "embedding_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--embedding-model") + if not model: + raise ValueError( + "No embedding model specified. Please provide a valid embedding model." + ) + params = [pytest.param(model, id="")] + + metafunc.parametrize("embedding_model", params, indirect=True) if "memory_stack" in metafunc.fixturenames: - metafunc.parametrize( - "memory_stack", - [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in MEMORY_FIXTURES - ], - indirect=True, + available_fixtures = { + "inference": INFERENCE_FIXTURES, + "memory": MEMORY_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS ) + metafunc.parametrize("memory_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..b5396b3ac 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,6 +10,8 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.inference import ModelInput, ModelType + from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig @@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") -async def memory_stack(request): - fixture_name = request.param - fixture = request.getfixturevalue(f"memory_{fixture_name}") +async def memory_stack(embedding_model, request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["inference", "memory"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) test_stack = await construct_stack_for_test( - [Api.memory], - {"memory": fixture.providers}, - fixture.provider_data, + [Api.memory, Api.inference], + providers, + provider_data, + models=[ + ModelInput( + model_id=embedding_model, + model_type=ModelType.embedding_model, + metadata={ + "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), + }, + ) + ], ) return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index b6e2e0a76..526aa646c 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -45,12 +45,14 @@ def sample_documents(): ] -async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: +async def register_memory_bank( + banks_impl: MemoryBanks, embedding_model: str +) -> MemoryBank: bank_id = f"test_bank_{uuid.uuid4().hex}" return await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=embedding_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: class TestMemory: @pytest.mark.asyncio - async def test_banks_list(self, memory_stack): + async def test_banks_list(self, memory_stack, embedding_model): _, banks_impl = memory_stack # Register a test bank - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, embedding_model) try: # Verify our bank shows up in list @@ -84,7 +86,7 @@ class TestMemory: ) @pytest.mark.asyncio - async def test_banks_register(self, memory_stack): + async def test_banks_register(self, memory_stack, embedding_model): _, banks_impl = memory_stack bank_id = f"test_bank_{uuid.uuid4().hex}" @@ -94,7 +96,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=embedding_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -109,7 +111,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=embedding_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -126,13 +128,15 @@ class TestMemory: await banks_impl.unregister_memory_bank(bank_id) @pytest.mark.asyncio - async def test_query_documents(self, memory_stack, sample_documents): + async def test_query_documents( + self, memory_stack, embedding_model, sample_documents + ): memory_impl, banks_impl = memory_stack with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, embedding_model) await memory_impl.insert_documents( registered_bank.memory_bank_id, sample_documents ) @@ -165,13 +169,13 @@ class TestMemory: # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.2} + params5 = {"score_threshold": 0.01} response5 = await memory_impl.query_documents( registered_bank.memory_bank_id, query5, params5 ) assert_valid_response(response5) print("The scores are:", response5.scores) - assert all(score >= 0.2 for score in response5.scores) + assert all(score >= 0.01 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse): diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index eb83aa671..cebe897bc 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) -ALL_MINILM_L6_V2_DIMENSION = 384 - -EMBEDDING_MODELS = {} - - -def get_embedding_model(model: str) -> "SentenceTransformer": - global EMBEDDING_MODELS - - loaded_model = EMBEDDING_MODELS.get(model) - if loaded_model is not None: - return loaded_model - - log.info(f"Loading sentence transformer for {model}...") - from sentence_transformers import SentenceTransformer - - loaded_model = SentenceTransformer(model) - EMBEDDING_MODELS[model] = loaded_model - return loaded_model - def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string @@ -166,12 +148,12 @@ class EmbeddingIndex(ABC): class BankWithIndex: bank: VectorMemoryBank index: EmbeddingIndex + inference_api: Api.inference async def insert_documents( self, documents: List[MemoryBankDocument], ) -> None: - model = get_embedding_model(self.bank.embedding_model) for doc in documents: content = await content_from_doc(doc) chunks = make_overlapped_chunks( @@ -183,7 +165,10 @@ class BankWithIndex: ) if not chunks: continue - embeddings = model.encode([x.content for x in chunks]).astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [x.content for x in chunks] + ) + embeddings = np.array(embeddings_response.embeddings) await self.index.add_chunks(chunks, embeddings) @@ -208,6 +193,8 @@ class BankWithIndex: else: query_str = _process(query) - model = get_embedding_model(self.bank.embedding_model) - query_vector = model.encode([query_str])[0].astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [query_str] + ) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) return await self.index.query(query_vector, k, score_threshold)