await initialize in faiss (#463)

tests:
```
 torchrun $CONDA_PREFIX/bin/pytest -v -s -m "faiss" llama_stack/providers/tests/memory/test_memory.py
```

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-15 14:21:31 -08:00 committed by GitHub
parent 20bf2f50c2
commit ff99025875
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -45,7 +45,12 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = {} self.chunk_by_index = {}
self.kvstore = kvstore self.kvstore = kvstore
self.bank_id = bank_id self.bank_id = bank_id
self.initialize()
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.kvstore: if not self.kvstore:
@ -132,7 +137,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks: for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data) bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex( index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore) bank=bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
),
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
@ -158,7 +166,9 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
# Store in cache # Store in cache
index = BankWithIndex( index = BankWithIndex(
bank=memory_bank, bank=memory_bank,
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore), index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
),
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index