updating to do batch inserts

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-02-17 16:36:46 -05:00
parent d777789958
commit 663396ec5a
2 changed files with 29 additions and 17 deletions


@@ -18,6 +18,7 @@ from llama_stack.providers.utils.kvstore.config import (
 class SQLiteVectorIOConfig(BaseModel):
     db_path: str
     kvstore: KVStoreConfig
+    batch_size: int = 500
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:

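Note: batch_size caps how many rows go into each executemany call in the insert loop below. A minimal, self-contained sketch of the slicing arithmetic that loop relies on (toy values, not the provider's API):

```python
# Toy illustration: a batch_size of 500 slices 1200 chunks into 500/500/200.
chunks = list(range(1200))
batch_size = 500
batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
assert [len(b) for b in batches] == [500, 500, 200]  # last batch is the remainder
```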

@@ -52,7 +52,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Create the table to store chunk metadata.
         cur.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id INTEGER PRIMARY KEY,
+                id TEXT PRIMARY KEY,
                 chunk TEXT
             );
         """)
@@ -74,33 +74,44 @@ class SQLiteVecIndex(EmbeddingIndex):
         Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
-        If any insert fails, the transaction is rolled back.
+        If any insert fails, the transaction is rolled back to maintain consistency.
         """
         cur = self.connection.cursor()
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            # Serialize and insert the chunk metadata.
-            chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
-            cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
-            # Fetch the last *n* inserted row_ids -- note: this is more reliable than `row_id = cur.lastrowid`
-            cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
-            row_ids = [row[0] for row in cur.fetchall()]
-            row_ids.reverse()  # Reverse to maintain the correct order of insertion
-            # Insert embeddings using the retrieved row IDs
-            embedding_data = [
-                (row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
-                for row_id, emb in zip(row_ids, embeddings, strict=True)
-            ]
-            cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
-            # Commit transaction if all inserts succeed
+            for i in range(0, len(chunks), self.config.batch_size):
+                batch_chunks = chunks[i : i + self.config.batch_size]
+                batch_embeddings = embeddings[i : i + self.config.batch_size]
+                # Prepare metadata inserts
+                metadata_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
+                    for j, chunk in enumerate(batch_chunks)
+                ]
+                # Insert metadata (ON CONFLICT to avoid duplicates)
+                cur.executemany(
+                    f"""
+                    INSERT INTO {self.metadata_table} (id, chunk)
+                    VALUES (?, ?)
+                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                    """,
+                    metadata_data,
+                )
+                # Prepare embeddings inserts
+                embedding_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
+                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
+                ]
+                # Insert embeddings in batch
+                cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?);", embedding_data)
             self.connection.commit()
         except sqlite3.Error as e:
             self.connection.rollback()  # Rollback on failure
             logger.error(f"Error inserting into {self.vector_table}: {e}")
         finally:
             cur.close()  # Ensure cursor is closed
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         """