diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index afeb7207b..4bd5fd5a7 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -16,6 +16,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, @@ -25,7 +26,6 @@ from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.telemetry import tracing from .config import FaissImplConfig -from llama_stack.providers.utils.kvstore import kvstore_impl logger = logging.getLogger(__name__) @@ -73,19 +73,18 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): self.config = config self.cache = {} self.kvstore = None - + async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) # Load existing banks from kvstore start_key = MEMORY_BANKS_PREFIX end_key = f"{MEMORY_BANKS_PREFIX}\xff" stored_banks = await self.kvstore.range(start_key, end_key) - + for bank_data in stored_banks: bank = VectorMemoryBankDef.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, - index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) self.cache[bank.identifier] = index @@ -110,8 +109,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): # Store in cache index = BankWithIndex( - bank=memory_bank, - index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) self.cache[memory_bank.identifier] = index