fix faiss serialize and serialize of index (#464)

faiss serialize index returns a np object, that we first need to save to
buffer and then write to sqllite. Since we are using json, we need to
base64 encode the data.

Same in the read path, we base64 decode and read into np array and then
call into deserialize index.

tests:
torchrun $CONDA_PREFIX/bin/pytest -v -s -m "faiss"
llama_stack/providers/tests/memory/test_memory.py

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-15 18:02:48 -08:00 committed by GitHub
parent ff99025875
commit 57bafd0f8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import io
import json import json
import logging import logging
@ -67,19 +68,20 @@ class FaissIndex(EmbeddingIndex):
for k, v in data["chunk_by_index"].items() for k, v in data["chunk_by_index"].items()
} }
index_bytes = base64.b64decode(data["faiss_index"]) buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
self.index = faiss.deserialize_index(index_bytes) self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
async def _save_index(self): async def _save_index(self):
if not self.kvstore or not self.bank_id: if not self.kvstore or not self.bank_id:
return return
index_bytes = faiss.serialize_index(self.index) np_index = faiss.serialize_index(self.index)
buffer = io.BytesIO()
np.savetxt(buffer, np_index)
data = { data = {
"id_by_index": self.id_by_index, "id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(index_bytes).decode(), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"faiss_index:v1::{self.bank_id}"
@ -188,7 +190,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
) -> None: ) -> None:
index = self.cache.get(bank_id) index = self.cache.get(bank_id)
if index is None: if index is None:
raise ValueError(f"Bank {bank_id} not found") raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
await index.insert_documents(documents) await index.insert_documents(documents)