updating to do batch inserts

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-02-17 16:36:46 -05:00
parent d777789958
commit 663396ec5a
2 changed files with 29 additions and 17 deletions

@@ -18,6 +18,7 @@ from llama_stack.providers.utils.kvstore.config import (
class SQLiteVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig
    batch_size: int = 500
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
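The new batch_size field defaults to 500 rows per insert batch and can be overridden wherever the provider config is constructed. A minimal sketch of that, assuming SqliteKVStoreConfig (and its db_path field) from the kvstore config module imported above:

from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = SQLiteVectorIOConfig(
    db_path="/tmp/sqlite_vec.db",
    kvstore=SqliteKVStoreConfig(db_path="/tmp/sqlite_vec_registry.db"),
    batch_size=1000,  # override the 500-row default for large ingests
)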

@@ -52,7 +52,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
id INTEGER PRIMARY KEY,
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
@@ -74,33 +74,44 @@ class SQLiteVecIndex(EmbeddingIndex):
Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back.
If any insert fails, the transaction is rolled back to maintain consistency.
"""
cur = self.connection.cursor()
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")
# Serialize and insert the chunk metadata.
chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
# Fetch the last *n* inserted row_ids -- note: this is more reliable than `row_id = cur.lastrowid`
cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
row_ids = [row[0] for row in cur.fetchall()]
row_ids.reverse() # Reverse to maintain the correct order of insertion
# Insert embeddings using the retrieved row IDs
embedding_data = [
(row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
for row_id, emb in zip(row_ids, embeddings, strict=True)
]
cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
# Commit transaction if all inserts succeed
for i in range(0, len(chunks), self.config.batch_size):
batch_chunks = chunks[i : i + self.config.batch_size]
batch_embeddings = embeddings[i : i + self.config.batch_size]
                # Prepare metadata inserts; enumerate from the batch offset so chunk ids
                # stay unique across batches of the same document.
                metadata_data = [
                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
                    for j, chunk in enumerate(batch_chunks, start=i)
                ]
                # Upsert metadata (ON CONFLICT keeps re-ingested chunks from duplicating rows)
                cur.executemany(
                    f"""
                    INSERT INTO {self.metadata_table} (id, chunk)
                    VALUES (?, ?)
                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
                    """,
                    metadata_data,
                )
                # Prepare embedding inserts keyed by the same offset-based chunk ids
                embedding_data = [
                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True), start=i)
                ]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?);", embedding_data)
self.connection.commit()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
logger.error(f"Error inserting into {self.vector_table}: {e}")
finally:
cur.close()
cur.close() # Ensure cursor is closed
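For reference, the pattern add_chunks now follows, reduced to a self-contained sketch against plain sqlite3 (the metadata table and column names here are illustrative stand-ins, not the provider's real schema): slice the rows into fixed-size batches, feed each slice to executemany, and wrap all batches in one transaction so a failure rolls every batch back.

import sqlite3

def batched_upsert(conn: sqlite3.Connection, rows: list[tuple[str, str]], batch_size: int = 500) -> None:
    # One transaction around all batches: either every row lands or none do.
    cur = conn.cursor()
    try:
        cur.execute("BEGIN TRANSACTION")
        for i in range(0, len(rows), batch_size):
            cur.executemany(
                "INSERT INTO metadata (id, chunk) VALUES (?, ?) "
                "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
                rows[i : i + batch_size],
            )
        conn.commit()
    except sqlite3.Error:
        conn.rollback()  # undo every batch, not just the failing one
        raise
    finally:
        cur.close()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (id TEXT PRIMARY KEY, chunk TEXT)")
batched_upsert(conn, [(f"doc:chunk-{n}", "{}") for n in range(1200)], batch_size=500)

Keeping a single transaction around the batch loop preserves the all-or-nothing behavior the docstring promises even when the data spans many batches.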
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""