await initialize in faiss

This commit is contained in:
Dinesh Yeduguru 2024-11-15 12:32:40 -08:00
parent 20bf2f50c2
commit 4db7c4d909

View file

@ -45,7 +45,12 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = {}
self.kvstore = kvstore
self.bank_id = bank_id
self.initialize()
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None:
if not self.kvstore:
@ -132,7 +137,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
bank=bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
),
)
self.cache[bank.identifier] = index
@ -158,7 +166,9 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
# Store in cache
index = BankWithIndex(
bank=memory_bank,
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
),
)
self.cache[memory_bank.identifier] = index