feat: Adding bulk inserts to sqlite-vec

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-02-13 21:57:12 -05:00
parent 6b1773d530
commit 058668d667

View file

@ -71,30 +71,31 @@ class SQLiteVecIndex(EmbeddingIndex):
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
"""
Add new chunks along with their embeddings.
For each chunk, we insert its JSON into the metadata table and then insert its
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back to maintain consistency.
Add new chunks along with their embeddings using batch inserts.
First inserts all chunk metadata in a batch, then inserts all embeddings in a batch,
using the assigned rowids. If any insert fails, the transaction is rolled back.
"""
cur = self.connection.cursor()
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")
for chunk, emb in zip(chunks, embeddings, strict=False):
# Serialize and insert the chunk metadata.
chunk_json = chunk.model_dump_json()
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
row_id = cur.lastrowid
# Ensure the embedding is a list of floats.
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
emb_blob = serialize_vector(emb_list)
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
# Fetch the last N inserted row IDs
cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
row_ids = [row[0] for row in cur.fetchall()]
row_ids.reverse() # Reverse to maintain the correct order of insertion
# Insert embeddings using the retrieved row IDs
embedding_data = [
(row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
for row_id, emb in zip(row_ids, embeddings)
]
cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
# Commit transaction if all inserts succeed
self.connection.commit()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
logger.error(f"Error inserting into {self.vector_table}: {e}")
finally:
cur.close() # Ensure cursor is closed