mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-08 00:11:30 +00:00
feat: Adding bulk inserts to sqlite-vec
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
6b1773d530
commit
058668d667
1 changed file with 16 additions and 15 deletions
|
|
@@ -71,30 +71,31 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
"""
|
||||
Add new chunks along with their embeddings.
|
||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
Add new chunks along with their embeddings using batch inserts.
|
||||
First inserts all chunk metadata in a batch, then inserts all embeddings in a batch,
|
||||
using the assigned rowids. If any insert fails, the transaction is rolled back.
|
||||
"""
|
||||
cur = self.connection.cursor()
|
||||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for chunk, emb in zip(chunks, embeddings, strict=False):
|
||||
# Serialize and insert the chunk metadata.
|
||||
chunk_json = chunk.model_dump_json()
|
||||
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
|
||||
row_id = cur.lastrowid
|
||||
# Ensure the embedding is a list of floats.
|
||||
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
|
||||
chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
|
||||
cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
|
||||
# Fetch the last N inserted row IDs
|
||||
cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
|
||||
row_ids = [row[0] for row in cur.fetchall()]
|
||||
row_ids.reverse() # Reverse to maintain the correct order of insertion
|
||||
# Insert embeddings using the retrieved row IDs
|
||||
embedding_data = [
|
||||
(row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
|
||||
for row_id, emb in zip(row_ids, embeddings)
|
||||
]
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
|
||||
# Commit transaction if all inserts succeed
|
||||
self.connection.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self.connection.rollback() # Rollback on failure
|
||||
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
|
||||
finally:
|
||||
cur.close() # Ensure cursor is closed
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue