diff --git a/llama_stack/providers/impls/meta_reference/memory/config.py b/llama_stack/providers/impls/meta_reference/memory/config.py
index b1c94c889..41970b05f 100644
--- a/llama_stack/providers/impls/meta_reference/memory/config.py
+++ b/llama_stack/providers/impls/meta_reference/memory/config.py
@@ -5,9 +5,17 @@
 # the root directory of this source tree.
 
 from llama_models.schema_utils import json_schema_type
-
 from pydantic import BaseModel
 
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
+
 
 @json_schema_type
-class FaissImplConfig(BaseModel): ...
+class FaissImplConfig(BaseModel):
+    kvstore: KVStoreConfig = SqliteKVStoreConfig(
+        db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
+    )  # Uses SQLite config specific to FAISS storage
diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py
index 02829f7be..4bd5fd5a7 100644
--- a/llama_stack/providers/impls/meta_reference/memory/faiss.py
+++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py
@@ -16,6 +16,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
@@ -28,6 +29,8 @@ from .config import FaissImplConfig
 
 logger = logging.getLogger(__name__)
 
+MEMORY_BANKS_PREFIX = "memory_banks:"
+
 
 class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
@@ -69,10 +72,25 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
+        self.kvstore = None
 
-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Load persisted bank definitions; each gets a newly constructed FaissIndex
+        start_key = MEMORY_BANKS_PREFIX
+        end_key = f"{MEMORY_BANKS_PREFIX}\xff"
+        stored_banks = await self.kvstore.range(start_key, end_key)
 
-    async def shutdown(self) -> None: ...
+        for bank_data in stored_banks:
+            bank = VectorMemoryBankDef.model_validate_json(bank_data)
+            index = BankWithIndex(
+                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+            )
+            self.cache[bank.identifier] = index
+
+    async def shutdown(self) -> None:
+        # Cleanup if needed
+        pass
 
     async def register_memory_bank(
         self,
@@ -82,6 +100,14 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
             memory_bank.type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.type}"
 
+        # Persist the definition as JSON so initialize() can model_validate_json it
+        key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
+        await self.kvstore.set(
+            key=key,
+            value=memory_bank.model_dump_json(),
+        )
+
+        # Store in cache
         index = BankWithIndex(
             bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
diff --git a/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py b/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py
new file mode 100644
index 000000000..b09abc2ed
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/memory/tests/test_faiss.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import tempfile
+
+import pytest
+from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
+from llama_stack.providers.impls.meta_reference.memory.config import FaissImplConfig
+
+from llama_stack.providers.impls.meta_reference.memory.faiss import FaissMemoryImpl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
+
+class TestFaissMemoryImpl:
+    @pytest.fixture
+    def faiss_impl(self):
+        # Use a throwaway directory so the SQLite file is removed after each test
+        with tempfile.TemporaryDirectory() as temp_dir:
+            kvstore = SqliteKVStoreConfig(db_path=f"{temp_dir}/faiss_store.db")
+            yield FaissMemoryImpl(FaissImplConfig(kvstore=kvstore))
+
+    @pytest.mark.asyncio
+    async def test_initialize(self, faiss_impl):
+        # Test empty initialization
+        await faiss_impl.initialize()
+        assert len(faiss_impl.cache) == 0
+
+        # Test initialization with existing banks
+        bank = VectorMemoryBankDef(
+            identifier="test_bank",
+            type=MemoryBankType.vector.value,
+            embedding_model="all-MiniLM-L6-v2",
+            chunk_size_in_tokens=512,
+            overlap_size_in_tokens=64,
+        )
+
+        # Register a bank and reinitialize to test loading
+        await faiss_impl.register_memory_bank(bank)
+
+        # Create new instance to test initialization with existing data
+        new_impl = FaissMemoryImpl(faiss_impl.config)
+        await new_impl.initialize()
+
+        assert len(new_impl.cache) == 1
+        assert "test_bank" in new_impl.cache
+
+    @pytest.mark.asyncio
+    async def test_register_memory_bank(self, faiss_impl):
+        bank = VectorMemoryBankDef(
+            identifier="test_bank",
+            type=MemoryBankType.vector.value,
+            embedding_model="all-MiniLM-L6-v2",
+            chunk_size_in_tokens=512,
+            overlap_size_in_tokens=64,
+        )
+
+        await faiss_impl.initialize()
+        await faiss_impl.register_memory_bank(bank)
+
+        assert "test_bank" in faiss_impl.cache
+        assert faiss_impl.cache["test_bank"].bank == bank
+
+        # Verify persistence
+        new_impl = FaissMemoryImpl(faiss_impl.config)
+        await new_impl.initialize()
+        assert "test_bank" in new_impl.cache
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])