Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
Update the sqlite_vec provider to do batch inserts
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent d777789958 · commit 663396ec5a

2 changed files with 29 additions and 17 deletions
@@ -18,6 +18,7 @@ from llama_stack.providers.utils.kvstore.config import (
 class SQLiteVectorIOConfig(BaseModel):
     db_path: str
     kvstore: KVStoreConfig
+    batch_size: int = 500

     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
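For reference, here is how the new field slots into the config. This is a minimal standalone sketch, not the provider's actual module: KVStoreConfig is stubbed out because its real definition lives in llama_stack.providers.utils.kvstore.config.

    from pydantic import BaseModel

    class KVStoreConfig(BaseModel):  # stand-in for llama_stack's KVStoreConfig
        namespace: str = "vector_io"

    class SQLiteVectorIOConfig(BaseModel):
        db_path: str
        kvstore: KVStoreConfig
        batch_size: int = 500  # rows per executemany() batch; a count, so int rather than bool

    config = SQLiteVectorIOConfig(db_path="/tmp/sqlite_vec.db", kvstore=KVStoreConfig())
    print(config.batch_size)  # 500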
@@ -52,7 +52,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Create the table to store chunk metadata.
         cur.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id INTEGER PRIMARY KEY,
+                id TEXT PRIMARY KEY,
                 chunk TEXT
             );
         """)
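The switch from INTEGER to TEXT primary key is what enables the upsert in the next hunk: a stable string id such as doc1:chunk-0 lets a re-ingested chunk update its existing row instead of appending a duplicate. A minimal standalone demonstration with plain sqlite3 (the table and id names here are illustrative, not the provider's):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE IF NOT EXISTS chunks (id TEXT PRIMARY KEY, chunk TEXT);")
    # Inserting the same id twice updates the row in place instead of duplicating it.
    rows = [("doc1:chunk-0", '{"v": 1}'), ("doc1:chunk-0", '{"v": 2}')]
    cur.executemany(
        "INSERT INTO chunks (id, chunk) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
        rows,
    )
    conn.commit()
    print(cur.execute("SELECT * FROM chunks").fetchall())  # [('doc1:chunk-0', '{"v": 2}')]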
@@ -74,33 +74,44 @@ class SQLiteVecIndex(EmbeddingIndex):
         Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
-        If any insert fails, the transaction is rolled back.
+        If any insert fails, the transaction is rolled back to maintain consistency.
         """
         cur = self.connection.cursor()
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            # Serialize and insert the chunk metadata.
-            chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
-            cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
-            # Fetch the last *n* inserted row_ids -- note: this is more reliable than `row_id = cur.lastrowid`
-            cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
-            row_ids = [row[0] for row in cur.fetchall()]
-            row_ids.reverse()  # Reverse to maintain the correct order of insertion
-            # Insert embeddings using the retrieved row IDs
-            embedding_data = [
-                (row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
-                for row_id, emb in zip(row_ids, embeddings, strict=True)
-            ]
-            cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
-            # Commit transaction if all inserts succeed
+            # Insert chunks and embeddings in slices of config.batch_size
+            for i in range(0, len(chunks), self.config.batch_size):
+                batch_chunks = chunks[i : i + self.config.batch_size]
+                batch_embeddings = embeddings[i : i + self.config.batch_size]
+                # Prepare metadata inserts; start=i keeps chunk ids unique across batches
+                metadata_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
+                    for j, chunk in enumerate(batch_chunks, start=i)
+                ]
+                # Insert metadata (ON CONFLICT to avoid duplicates)
+                cur.executemany(
+                    f"""
+                    INSERT INTO {self.metadata_table} (id, chunk)
+                    VALUES (?, ?)
+                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                    """,
+                    metadata_data,
+                )
+                # Prepare embeddings inserts, keyed by the same chunk ids
+                embedding_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
+                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True), start=i)
+                ]
+                # Insert embeddings in batch
+                cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?);", embedding_data)
             self.connection.commit()

         except sqlite3.Error as e:
             self.connection.rollback()  # Rollback on failure
             logger.error(f"Error inserting into {self.vector_table}: {e}")

         finally:
             cur.close()  # Ensure cursor is closed

     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         """
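The hunk calls serialize_vector(), which is defined elsewhere in the provider. A stand-in consistent with sqlite-vec's raw float32 blob format (the exact encoding is an assumption here, not taken from this diff) could look like:

    import struct
    from typing import List

    def serialize_vector(vector: List[float]) -> bytes:
        """Pack a list of floats into a contiguous float32 blob."""
        return struct.pack(f"{len(vector)}f", *vector)

    print(len(serialize_vector([0.1, 0.2, 0.3])))  # 12 bytes = 3 * 4-byte floats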
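Putting the pieces together, the batching pattern from add_chunks can be exercised end to end against an in-memory database. This sketch substitutes a plain table for the sqlite-vec virtual table and synthetic JSON for real chunks, so only the id scheme and transaction flow are representative:

    import sqlite3

    BATCH_SIZE = 2
    chunks = [f'{{"text": "chunk {n}"}}' for n in range(5)]

    # autocommit mode so the explicit BEGIN below controls the transaction
    conn = sqlite3.connect(":memory:", isolation_level=None)
    cur = conn.cursor()
    cur.execute("CREATE TABLE IF NOT EXISTS metadata (id TEXT PRIMARY KEY, chunk TEXT);")
    try:
        cur.execute("BEGIN TRANSACTION")
        for i in range(0, len(chunks), BATCH_SIZE):
            batch = chunks[i : i + BATCH_SIZE]
            # start=i keeps ids globally unique across batches
            rows = [(f"doc1:chunk-{j}", c) for j, c in enumerate(batch, start=i)]
            cur.executemany(
                "INSERT INTO metadata (id, chunk) VALUES (?, ?) "
                "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
                rows,
            )
        conn.commit()  # commit once every batch has succeeded
    except sqlite3.Error as e:
        conn.rollback()  # any failed batch rolls the whole ingest back
        print(f"insert failed: {e}")
    finally:
        cur.close()

    print(conn.execute("SELECT COUNT(*) FROM metadata").fetchone()[0])  # 5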