forked from phoenix-oss/llama-stack-mirror
fix faiss serialize and serialize of index (#464)
faiss serialize index returns a np object, that we first need to save to buffer and then write to sqllite. Since we are using json, we need to base64 encode the data. Same in the read path, we base64 decode and read into np array and then call into deserialize index. tests: torchrun $CONDA_PREFIX/bin/pytest -v -s -m "faiss" llama_stack/providers/tests/memory/test_memory.py Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
ff99025875
commit
57bafd0f8c
1 changed files with 8 additions and 6 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
@ -67,19 +68,20 @@ class FaissIndex(EmbeddingIndex):
|
|||
for k, v in data["chunk_by_index"].items()
|
||||
}
|
||||
|
||||
index_bytes = base64.b64decode(data["faiss_index"])
|
||||
self.index = faiss.deserialize_index(index_bytes)
|
||||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
|
||||
|
||||
async def _save_index(self):
|
||||
if not self.kvstore or not self.bank_id:
|
||||
return
|
||||
|
||||
index_bytes = faiss.serialize_index(self.index)
|
||||
|
||||
np_index = faiss.serialize_index(self.index)
|
||||
buffer = io.BytesIO()
|
||||
np.savetxt(buffer, np_index)
|
||||
data = {
|
||||
"id_by_index": self.id_by_index,
|
||||
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
||||
"faiss_index": base64.b64encode(index_bytes).decode(),
|
||||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||
}
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
|
@ -188,7 +190,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
) -> None:
|
||||
index = self.cache.get(bank_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue