forked from phoenix-oss/llama-stack-mirror
await initialize in faiss (#463)
tests: ``` torchrun $CONDA_PREFIX/bin/pytest -v -s -m "faiss" llama_stack/providers/tests/memory/test_memory.py ``` Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
20bf2f50c2
commit
ff99025875
1 changed files with 13 additions and 3 deletions
|
@ -45,7 +45,12 @@ class FaissIndex(EmbeddingIndex):
|
||||||
self.chunk_by_index = {}
|
self.chunk_by_index = {}
|
||||||
self.kvstore = kvstore
|
self.kvstore = kvstore
|
||||||
self.bank_id = bank_id
|
self.bank_id = bank_id
|
||||||
self.initialize()
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
|
||||||
|
instance = cls(dimension, kvstore, bank_id)
|
||||||
|
await instance.initialize()
|
||||||
|
return instance
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
if not self.kvstore:
|
if not self.kvstore:
|
||||||
|
@ -132,7 +137,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
for bank_data in stored_banks:
|
for bank_data in stored_banks:
|
||||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
|
bank=bank,
|
||||||
|
index=await FaissIndex.create(
|
||||||
|
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.cache[bank.identifier] = index
|
self.cache[bank.identifier] = index
|
||||||
|
|
||||||
|
@ -158,7 +166,9 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
# Store in cache
|
# Store in cache
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=memory_bank,
|
bank=memory_bank,
|
||||||
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
|
index=await FaissIndex.create(
|
||||||
|
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue