From 0e451525e5bf289814c674e5da85b393465f6da9 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 9 Dec 2024 15:00:12 -0800
Subject: [PATCH] remove InferenceEmbeddingMixin and fix memory tests

---
 .../providers/inline/memory/faiss/faiss.py    | 10 ++-
 .../providers/remote/memory/chroma/chroma.py  | 17 ++--
 .../remote/memory/pgvector/pgvector.py        | 14 ++--
 .../providers/remote/memory/qdrant/qdrant.py  | 11 ++-
 .../remote/memory/weaviate/weaviate.py        |  8 +-
 .../providers/tests/memory/conftest.py        | 80 +++++++++++++++++--
 .../providers/tests/memory/fixtures.py        | 30 +++++--
 .../providers/tests/memory/test_memory.py     | 22 ++---
 .../providers/utils/memory/vector_store.py    | 17 ----
 9 files changed, 140 insertions(+), 69 deletions(-)

diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py
index cb090c870..9c52930be 100644
--- a/llama_stack/providers/inline/memory/faiss/faiss.py
+++ b/llama_stack/providers/inline/memory/faiss/faiss.py
@@ -23,8 +23,8 @@
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )

 from .config import FaissImplConfig
@@ -131,7 +131,7 @@
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class FaissMemoryImpl(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -147,11 +147,12 @@

         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = self._create_bank_with_index(
+            index = BankWithIndex(
                 bank,
                 await FaissIndex.create(
                     bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index

@@ -175,11 +176,12 @@
         )

         # Store in cache
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+        self.cache[memory_bank.identifier] = BankWithIndex(
             memory_bank,
             await FaissIndex.create(
                 memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
+            self.inference_api,
         )

     async def list_memory_banks(self) -> List[MemoryBank]:
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py
index f2b48a3be..f073feda3 100644
--- a/llama_stack/providers/remote/memory/chroma/chroma.py
+++ b/llama_stack/providers/remote/memory/chroma/chroma.py
@@ -15,12 +15,10 @@
 from numpy.typing import NDArray
 from pydantic import parse_obj_as
 from llama_stack.apis.memory import *  # noqa: F403
-
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )

 log = logging.getLogger(__name__)
@@ -72,7 +70,7 @@
         await self.client.delete_collection(self.collection.name)


-class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
+class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str, inference_api: Api.inference) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -111,8 +109,8 @@
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )

     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -125,9 +123,10 @@
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue

-            self.cache[bank.identifier] = self._create_bank_with_index(
+            self.cache[bank.identifier] = BankWithIndex(
                 bank,
                 ChromaIndex(self.client, collection),
+                self.inference_api,
             )

         return [i.bank for i in self.cache.values()]
@@ -166,6 +165,8 @@
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py
index 18d732534..ed1e61a67 100644
--- a/llama_stack/providers/remote/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py
@@ -21,7 +21,6 @@ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )

 from .config import PGVectorConfig
@@ -120,9 +119,7 @@
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")


-class PGVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -171,8 +168,8 @@
         upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])

         index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
-            memory_bank, index
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
         )

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
@@ -183,9 +180,10 @@
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks:
             if bank.identifier not in self.cache:
-                index = self._create_bank_with_index(
+                index = BankWithIndex(
                     bank,
                     PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
+                    self.inference_api,
                 )
                 self.cache[bank.identifier] = index
         return banks
@@ -216,5 +214,5 @@

         bank = await self.memory_bank_store.get_memory_bank(bank_id)
         index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
-        self.cache[bank_id] = self._create_bank_with_index(bank, index)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
         return self.cache[bank_id]
diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py
index f2c36438d..f2f28e63a 100644
--- a/llama_stack/providers/remote/memory/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py
@@ -21,7 +21,6 @@ from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )

 log = logging.getLogger(__name__)
@@ -101,9 +100,7 @@
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class QdrantVectorMemoryAdapter(
-    InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
-):
+class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
@@ -124,9 +121,10 @@
             memory_bank.memory_bank_type == MemoryBankType.vector
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
         self.cache[memory_bank.identifier] = index

@@ -144,9 +142,10 @@
         if not bank:
             raise ValueError(f"Bank {bank_id} not found")

-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py
index 954bdcc68..3fa9ace51 100644
--- a/llama_stack/providers/remote/memory/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py
@@ -19,7 +19,6 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
-    InferenceEmbeddingMixin,
 )

 from .config import WeaviateConfig, WeaviateRequestProviderData
@@ -83,7 +82,6 @@


 class WeaviateMemoryAdapter(
-    InferenceEmbeddingMixin,
     Memory,
     NeedsRequestProviderData,
     MemoryBanksProtocolPrivate,
@@ -140,9 +138,10 @@
             ],
         )

-        self.cache[memory_bank.identifier] = self._create_bank_with_index(
+        self.cache[memory_bank.identifier] = BankWithIndex(
             memory_bank,
             WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )

     async def list_memory_banks(self) -> List[MemoryBank]:
@@ -164,9 +163,10 @@
         if not client.collections.exists(bank.identifier):
             raise ValueError(f"Collection with name `{bank.identifier}` not found")

-        index = self._create_bank_with_index(
+        index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py
index 99ecbe794..023a1a156 100644
--- a/llama_stack/providers/tests/memory/conftest.py
+++ b/llama_stack/providers/tests/memory/conftest.py
@@ -6,9 +6,65 @@

 import pytest

+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES


+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -18,12 +74,22 @@


 def pytest_generate_tests(metafunc):
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if not model:
+            raise ValueError(
+                "No embedding model specified. Please provide a valid embedding model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("embedding_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
         )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index c9559b61c..b5396b3ac 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio

+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
@@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(embedding_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)

     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=embedding_model,
+                model_type=ModelType.embedding_model,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )

     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index b6e2e0a76..85f4351f8 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -45,12 +45,14 @@
     ]


-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, embedding_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@

 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)

         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@
         )

     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=embedding_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -109,7 +111,7 @@
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=embedding_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -126,13 +128,15 @@
             await banks_impl.unregister_memory_bank(bank_id)

     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, embedding_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack

         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 8ff91a36e..cebe897bc 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -198,20 +198,3 @@
         )
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
-
-
-class InferenceEmbeddingMixin:
-    inference_api: Api.inference
-
-    def __init__(self, inference_api: Api.inference):
-        self.inference_api = inference_api
-
-    def _create_bank_with_index(
-        self, bank: VectorMemoryBank, index: EmbeddingIndex
-    ) -> BankWithIndex:
-
-        return BankWithIndex(
-            bank=bank,
-            index=index,
-            inference_api=self.inference_api,
-        )
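
The mechanical change at every call site in this patch is the same: drop the inherited self._create_bank_with_index(bank, index) helper and construct BankWithIndex(bank, index, self.inference_api) directly, which is what lets InferenceEmbeddingMixin be deleted from vector_store.py. A minimal sketch of the resulting shape follows; the stand-in types, the SomeMemoryAdapter class, and the cache_bank helper are illustrative only, and BankWithIndex's field order is inferred from the positional constructor calls above rather than copied from the source.

# Illustrative sketch only -- simplified stand-ins, not the llama_stack
# definitions. BankWithIndex's (bank, index, inference_api) field order is
# inferred from the positional calls in the patch above.
from dataclasses import dataclass
from typing import Any, Dict


class EmbeddingIndex:
    """Stand-in for a provider index (FaissIndex, ChromaIndex, QdrantIndex, ...)."""


@dataclass
class BankWithIndex:
    bank: Any              # VectorMemoryBank in the real code
    index: EmbeddingIndex  # provider-specific vector index
    inference_api: Any     # embeds queries/documents before hitting the index


class SomeMemoryAdapter:
    """Hypothetical adapter showing the post-patch pattern: no mixin,
    the inference API is an explicit constructor dependency."""

    def __init__(self, inference_api: Any) -> None:
        self.inference_api = inference_api
        self.cache: Dict[str, BankWithIndex] = {}

    def cache_bank(self, bank_id: str, bank: Any, index: EmbeddingIndex) -> BankWithIndex:
        # Before the patch: self._create_bank_with_index(bank, index), with
        # inference_api supplied invisibly by InferenceEmbeddingMixin.
        entry = BankWithIndex(bank, index, self.inference_api)
        self.cache[bank_id] = entry
        return entry

The trade-off is one extra argument repeated at roughly ten call sites in exchange for deleting a base class whose only state was inference_api: each adapter's dependency on the inference API is now visible in its own __init__ and at every BankWithIndex construction, and method resolution order no longer matters.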