add support for provider update and unregister for memory banks

Dinesh Yeduguru 2024-11-14 16:08:24 -08:00
parent 9b75e92852
commit e8b699797c
11 changed files with 240 additions and 65 deletions


@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
+import json
 import logging
 from typing import Any, Dict, List, Optional
@@ -37,10 +39,57 @@ class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
     chunk_by_index: Dict[int, str]

-    def __init__(self, dimension: int):
+    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
         self.index = faiss.IndexFlatL2(dimension)
         self.id_by_index = {}
         self.chunk_by_index = {}
+        self.kvstore = kvstore
+        self.bank_id = bank_id
+        self.initialize()
+
+    async def initialize(self) -> None:
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Load existing index data from kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        stored_data = await self.kvstore.get(index_key)
+
+        if stored_data:
+            data = json.loads(stored_data)
+            self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
+            self.chunk_by_index = {
+                int(k): Chunk.model_validate_json(v)
+                for k, v in data["chunk_by_index"].items()
+            }
+            # Load FAISS index
+            index_bytes = base64.b64decode(data["faiss_index"])
+            self.index = faiss.deserialize_index(index_bytes)
+
+    async def _save_index(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        # Serialize FAISS index
+        index_bytes = faiss.serialize_index(self.index)
+
+        # Prepare data for storage
+        data = {
+            "id_by_index": self.id_by_index,
+            "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
+            "faiss_index": base64.b64encode(index_bytes).decode(),
+        }
+
+        # Store in kvstore
+        index_key = f"faiss_index:v1::{self.bank_id}"
+        await self.kvstore.set(key=index_key, value=json.dumps(data))
+
+    async def delete(self):
+        if not self.kvstore or not self.bank_id:
+            return
+
+        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")

     @tracing.span(name="add_chunks")
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
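
The persisted payload is a single JSON document per bank: the id and chunk maps, plus the FAISS index serialized to bytes and base64-encoded so it survives JSON encoding. Below is a minimal sketch of that round-trip, independent of the provider code. Note that faiss's deserialize_index helper expects a uint8 numpy array, so the decoded bytes are rewrapped with np.frombuffer here:

import base64
import json

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatL2(dim)
index.add(np.random.rand(4, dim).astype(np.float32))

# serialize_index returns a numpy uint8 array; b64encode accepts it
# through the buffer protocol.
payload = json.dumps(
    {"faiss_index": base64.b64encode(faiss.serialize_index(index)).decode()}
)

# Round-trip: decode, then rewrap as a uint8 array for deserialize_index.
raw = base64.b64decode(json.loads(payload)["faiss_index"])
restored = faiss.deserialize_index(np.frombuffer(raw, dtype=np.uint8))
assert restored.ntotal == index.ntotal == 4
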
@@ -51,6 +100,9 @@ class FaissIndex(EmbeddingIndex):
         self.index.add(np.array(embeddings).astype(np.float32))

+        # Save updated index
+        await self._save_index()
+
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
     ) -> QueryDocumentsResponse:
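
Since add_chunks now persists after every insert, anything with async get/set/delete over string keys can back the index. A toy dict-backed stand-in (hypothetical, for illustration only; the real provider wires in its configured kvstore, and FaissIndex from the diff above is assumed importable) shows the key layout:

import asyncio

class DictKVStore:
    """Hypothetical in-memory stand-in for the kvstore protocol the
    index uses: async get/set/delete over string keys."""

    def __init__(self) -> None:
        self._data: dict = {}

    async def get(self, key: str):
        return self._data.get(key)

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

async def demo() -> None:
    store = DictKVStore()
    # 384 matches ALL_MINILM_L6_V2_DIMENSION.
    index = FaissIndex(384, kvstore=store, bank_id="bank-1")
    # __init__ calls initialize() without awaiting it (the coroutine is
    # created and discarded), so await it explicitly; a no-op on first run.
    await index.initialize()
    # _save_index is private, but it is the shortest way to exercise
    # persistence without constructing Chunk objects.
    await index._save_index()
    assert await store.get("faiss_index:v1::bank-1") is not None
    await index.delete()  # removes only this bank's persisted payload
    assert await store.get("faiss_index:v1::bank-1") is None

asyncio.run(demo())
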
@@ -85,7 +137,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+                bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
             )
             self.cache[bank.identifier] = index
@@ -110,13 +162,28 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):

         # Store in cache
         index = BankWithIndex(
-            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
         )
         self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]

+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        await self.cache[memory_bank_id].index.delete()
+        del self.cache[memory_bank_id]
+        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
+
+    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
+        # Not possible to update the index in place, so we delete and recreate
+        await self.cache[memory_bank.identifier].index.delete()
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            bank=memory_bank,
+            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
+        )
+
     async def insert_documents(
         self,
         bank_id: str,
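
With these two methods, updates follow a delete-and-recreate path rather than an in-place one, so an updated bank comes back empty and must be repopulated. A hypothetical driver sketch (rebuild_bank and drop_bank are illustrative names; insert_documents' full signature is truncated above, and a (bank_id, documents) form is assumed):

async def rebuild_bank(impl, bank, documents) -> None:
    # update_memory_bank drops the persisted FAISS payload and installs
    # a fresh, empty index, so documents must be re-inserted afterwards.
    await impl.update_memory_bank(bank)
    await impl.insert_documents(bank.identifier, documents)

async def drop_bank(impl, bank_id: str) -> None:
    # Removes the bank's index payload, its cache entry, and its
    # registration record under MEMORY_BANKS_PREFIX.
    await impl.unregister_memory_bank(bank_id)
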