mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
updating to do batch inserts
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent d777789958 · commit 663396ec5a
2 changed files with 29 additions and 17 deletions
@@ -18,6 +18,7 @@ from llama_stack.providers.utils.kvstore.config import (
 class SQLiteVectorIOConfig(BaseModel):
     db_path: str
     kvstore: KVStoreConfig
+    batch_size: int = 500
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
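For context, a minimal sketch (not part of the commit) of how a batch_size field like the one added above can drive batched iteration. The ExampleVectorIOConfig class and iter_batches helper are illustrative names, not the provider's API; pydantic is assumed to be installed, as it is in the surrounding code.

# Illustrative sketch: a config carrying batch_size and a helper that yields
# fixed-size batches, mirroring the slicing pattern used in add_chunks below.
from typing import Any, Iterator, List

from pydantic import BaseModel


class ExampleVectorIOConfig(BaseModel):
    db_path: str
    batch_size: int = 500  # same default as the field added in this commit


def iter_batches(items: List[Any], batch_size: int) -> Iterator[List[Any]]:
    # Slice the list into consecutive batches of at most batch_size items.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


if __name__ == "__main__":
    cfg = ExampleVectorIOConfig(db_path="/tmp/example.db", batch_size=3)
    print([len(b) for b in iter_batches(list(range(10)), cfg.batch_size)])  # [3, 3, 3, 1]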
@@ -52,7 +52,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Create the table to store chunk metadata.
         cur.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id INTEGER PRIMARY KEY,
+                id TEXT PRIMARY KEY,
                 chunk TEXT
             );
         """)
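The schema change above swaps the auto-assigned INTEGER key for a TEXT key, which pairs with the deterministic "document_id:chunk-<index>" ids used in the insert hunk below: re-ingesting the same document updates rows instead of duplicating them. A small standalone sketch of that upsert behavior, using only the standard library sqlite3 module; the chunks_demo table name and sample values are assumptions, not the provider's schema.

# Sketch: TEXT primary key + ON CONFLICT upsert keeps one row per chunk id.
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS chunks_demo (id TEXT PRIMARY KEY, chunk TEXT);")

chunk_id = "doc-1:chunk-0"  # deterministic id: "<document_id>:chunk-<index>"
upsert_sql = (
    "INSERT INTO chunks_demo (id, chunk) VALUES (?, ?) "
    "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;"
)
cur.execute(upsert_sql, (chunk_id, '{"content": "v1"}'))
cur.execute(upsert_sql, (chunk_id, '{"content": "v2"}'))  # same id: updates in place
conn.commit()

cur.execute("SELECT COUNT(*), MAX(chunk) FROM chunks_demo;")
print(cur.fetchone())  # (1, '{"content": "v2"}') -- one row, latest content
conn.close()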
@@ -74,33 +74,44 @@ class SQLiteVecIndex(EmbeddingIndex):
         Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
-        If any insert fails, the transaction is rolled back.
+        If any insert fails, the transaction is rolled back to maintain consistency.
         """
         cur = self.connection.cursor()
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            # Serialize and insert the chunk metadata.
-            chunk_data = [(chunk.model_dump_json(),) for chunk in chunks]
-            cur.executemany(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", chunk_data)
-            # Fetch the last *n* inserted row_ids -- note: this is more reliable than `row_id = cur.lastrowid`
-            cur.execute(f"SELECT rowid FROM {self.metadata_table} ORDER BY rowid DESC LIMIT {len(chunks)}")
-            row_ids = [row[0] for row in cur.fetchall()]
-            row_ids.reverse()  # Reverse to maintain the correct order of insertion
-            # Insert embeddings using the retrieved row IDs
-            embedding_data = [
-                (row_id, serialize_vector(emb.tolist() if isinstance(emb, np.ndarray) else list(emb)))
-                for row_id, emb in zip(row_ids, embeddings, strict=True)
-            ]
-            cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", embedding_data)
-            # Commit transaction if all inserts succeed
+            for i in range(0, len(chunks), self.config.batch_size):
+                batch_chunks = chunks[i : i + self.config.batch_size]
+                batch_embeddings = embeddings[i : i + self.config.batch_size]
+                # Prepare metadata inserts
+                metadata_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
+                    for j, chunk in enumerate(batch_chunks)
+                ]
+                # Insert metadata (ON CONFLICT to avoid duplicates)
+                cur.executemany(
+                    f"""
+                    INSERT INTO {self.metadata_table} (id, chunk)
+                    VALUES (?, ?)
+                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                    """,
+                    metadata_data,
+                )
+                # Prepare embeddings inserts
+                embedding_data = [
+                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
+                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
+                ]
+                # Insert embeddings in batch
+                cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?);", embedding_data)
             self.connection.commit()
 
         except sqlite3.Error as e:
             self.connection.rollback()  # Rollback on failure
             logger.error(f"Error inserting into {self.vector_table}: {e}")
 
         finally:
-            cur.close()
+            cur.close()  # Ensure cursor is closed
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         """
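The rewritten add_chunks body follows a common pattern: executemany per batch inside one explicit transaction, committed only if every batch succeeds and rolled back otherwise. A self-contained sketch of that pattern with the standard library sqlite3 module; the chunks_demo table, BATCH_SIZE constant, and add_rows helper are illustrative stand-ins, not the provider's code.

# Sketch: batched upserts inside a single transaction, with rollback on error.
import sqlite3
from typing import List, Tuple

BATCH_SIZE = 500  # plays the role of config.batch_size


def add_rows(conn: sqlite3.Connection, rows: List[Tuple[str, str]]) -> None:
    cur = conn.cursor()
    try:
        cur.execute("BEGIN TRANSACTION")
        for i in range(0, len(rows), BATCH_SIZE):
            batch = rows[i : i + BATCH_SIZE]
            cur.executemany(
                "INSERT INTO chunks_demo (id, chunk) VALUES (?, ?) "
                "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
                batch,
            )
        conn.commit()  # commit only if every batch succeeded
    except sqlite3.Error:
        conn.rollback()  # keep the table consistent if any batch fails
        raise
    finally:
        cur.close()


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE chunks_demo (id TEXT PRIMARY KEY, chunk TEXT);")
    add_rows(conn, [(f"doc-1:chunk-{j}", "{}") for j in range(1200)])
    print(conn.execute("SELECT COUNT(*) FROM chunks_demo;").fetchone()[0])  # 1200
    conn.close()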