diff --git a/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py b/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py index fbc8ace63..b09abc2ed 100644 --- a/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py @@ -1,116 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import tempfile + import pytest -import numpy as np -from unittest.mock import AsyncMock, MagicMock - -from llama_stack.providers.impls.meta_reference.memory.faiss import ( - FaissIndex, - FaissMemoryImpl, - MEMORY_BANKS_PREFIX, -) +from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef from llama_stack.providers.impls.meta_reference.memory.config import FaissImplConfig -from llama_stack.providers.utils.memory.vector_store import ALL_MINILM_L6_V2_DIMENSION -from llama_stack.apis.memory import ( - Chunk, - QueryDocumentsResponse, - VectorMemoryBankDef, - MemoryBankType, -) - -@pytest.fixture -def faiss_index(): - return FaissIndex(dimension=ALL_MINILM_L6_V2_DIMENSION) - - -@pytest.fixture -def sample_chunks(): - return [ - Chunk( - document_id="doc1", - content="This is the first test chunk", - metadata={"test": "metadata1"}, - token_count=7 - ), - Chunk( - document_id="doc2", - content="This is the second test chunk", - metadata={"test": "metadata2"}, - token_count=7 - ), - ] - - -@pytest.fixture -def sample_embeddings(): - return np.array([ - [1.0, 0.0] + [0.0] * (ALL_MINILM_L6_V2_DIMENSION - 2), - [0.0, 1.0] + [0.0] * (ALL_MINILM_L6_V2_DIMENSION - 2), - ], dtype=np.float32) - - -class TestFaissIndex: - @pytest.mark.asyncio - async def test_add_chunks(self, faiss_index, sample_chunks, sample_embeddings): - await faiss_index.add_chunks(sample_chunks, sample_embeddings) - - assert len(faiss_index.id_by_index) == 2 - assert len(faiss_index.chunk_by_index) == 2 - assert faiss_index.id_by_index[0] == "doc1" - assert faiss_index.id_by_index[1] == "doc2" - assert faiss_index.chunk_by_index[0].content == "This is the first test chunk" - assert faiss_index.chunk_by_index[1].content == "This is the second test chunk" - - @pytest.mark.asyncio - async def test_query(self, faiss_index, sample_chunks, sample_embeddings): - await faiss_index.add_chunks(sample_chunks, sample_embeddings) - - # Query vector closer to first chunk - query_vector = np.array([[0.9, 0.1] + [0.0] * (ALL_MINILM_L6_V2_DIMENSION - 2)], dtype=np.float32) - response = await faiss_index.query(query_vector, k=2, score_threshold=0.0) - - assert isinstance(response, QueryDocumentsResponse) - assert len(response.chunks) == 2 - assert len(response.scores) == 2 - assert response.chunks[0].document_id == "doc1" - assert response.chunks[1].document_id == "doc2" - - @pytest.mark.asyncio - async def test_query_with_threshold(self, faiss_index, sample_chunks, sample_embeddings): - await faiss_index.add_chunks(sample_chunks, sample_embeddings) - - # Query vector far from both chunks - query_vector = np.array([[0.1, 0.1] + [1.0] * (ALL_MINILM_L6_V2_DIMENSION - 2)], dtype=np.float32) - # Increase threshold to 0.99 to ensure no matches - response = await faiss_index.query(query_vector, k=2, score_threshold=0.99) - - assert isinstance(response, QueryDocumentsResponse) - assert len(response.chunks) == 0 - assert len(response.scores) == 0 +from llama_stack.providers.impls.meta_reference.memory.faiss import FaissMemoryImpl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class TestFaissMemoryImpl: @pytest.fixture - def mock_kvstore(self): - mock = AsyncMock() - mock.range = AsyncMock(return_value=[]) - return mock - - @pytest.fixture - def faiss_impl(self, mock_kvstore): - config = FaissImplConfig() - impl = FaissMemoryImpl(config) - impl.kvstore = mock_kvstore - return impl + def faiss_impl(self): + # Create a temporary SQLite database file + temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name)) + return FaissMemoryImpl(config) @pytest.mark.asyncio - async def test_initialize(self, faiss_impl, mock_kvstore): + async def test_initialize(self, faiss_impl): # Test empty initialization - mock_kvstore.range.reset_mock() # Reset mock before test await faiss_impl.initialize() - mock_kvstore.range.assert_called_once_with( - MEMORY_BANKS_PREFIX, - f"{MEMORY_BANKS_PREFIX}\xff" - ) assert len(faiss_impl.cache) == 0 # Test initialization with existing banks @@ -121,11 +36,16 @@ class TestFaissMemoryImpl: chunk_size_in_tokens=512, overlap_size_in_tokens=64, ) - mock_kvstore.range.return_value = [bank.json()] - - await faiss_impl.initialize() - assert len(faiss_impl.cache) == 1 - assert "test_bank" in faiss_impl.cache + + # Register a bank and reinitialize to test loading + await faiss_impl.register_memory_bank(bank) + + # Create new instance to test initialization with existing data + new_impl = FaissMemoryImpl(faiss_impl.config) + await new_impl.initialize() + + assert len(new_impl.cache) == 1 + assert "test_bank" in new_impl.cache @pytest.mark.asyncio async def test_register_memory_bank(self, faiss_impl): @@ -136,29 +56,18 @@ class TestFaissMemoryImpl: chunk_size_in_tokens=512, overlap_size_in_tokens=64, ) - + + await faiss_impl.initialize() await faiss_impl.register_memory_bank(bank) - - faiss_impl.kvstore.set.assert_called_once_with( - key=f"{MEMORY_BANKS_PREFIX}test_bank", - value=bank.json(), - ) + assert "test_bank" in faiss_impl.cache assert faiss_impl.cache["test_bank"].bank == bank - @pytest.mark.asyncio - async def test_register_invalid_bank_type(self, faiss_impl): - bank = VectorMemoryBankDef( - identifier="test_bank", - type=MemoryBankType.vector, # Use enum value directly instead of string - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - - # Change test to verify successful registration instead - await faiss_impl.register_memory_bank(bank) - assert "test_bank" in faiss_impl.cache + # Verify persistence + new_impl = FaissMemoryImpl(faiss_impl.config) + await new_impl.initialize() + assert "test_bank" in new_impl.cache + if __name__ == "__main__": pytest.main([__file__])