From 57bafd0f8c61dcdff86701aeb2be40ef8175b953 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 15 Nov 2024 18:02:48 -0800 Subject: [PATCH] fix faiss serialize and serialize of index (#464) faiss serialize index returns a np object, that we first need to save to buffer and then write to sqllite. Since we are using json, we need to base64 encode the data. Same in the read path, we base64 decode and read into np array and then call into deserialize index. tests: torchrun $CONDA_PREFIX/bin/pytest -v -s -m "faiss" llama_stack/providers/tests/memory/test_memory.py Co-authored-by: Dinesh Yeduguru --- llama_stack/providers/inline/memory/faiss/faiss.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 07c42d389..95791bc69 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import base64 +import io import json import logging @@ -67,19 +68,20 @@ class FaissIndex(EmbeddingIndex): for k, v in data["chunk_by_index"].items() } - index_bytes = base64.b64decode(data["faiss_index"]) - self.index = faiss.deserialize_index(index_bytes) + buffer = io.BytesIO(base64.b64decode(data["faiss_index"])) + self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8)) async def _save_index(self): if not self.kvstore or not self.bank_id: return - index_bytes = faiss.serialize_index(self.index) - + np_index = faiss.serialize_index(self.index) + buffer = io.BytesIO() + np.savetxt(buffer, np_index) data = { "id_by_index": self.id_by_index, "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, - "faiss_index": base64.b64encode(index_bytes).decode(), + "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } index_key = f"faiss_index:v1::{self.bank_id}" @@ -188,7 +190,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): ) -> None: index = self.cache.get(bank_id) if index is None: - raise ValueError(f"Bank {bank_id} not found") + raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}") await index.insert_documents(documents)