feat: Chunk sqlite-vec writes (#1094)

# What does this PR do?
1. This PR adds batch inserts into sqlite-vec as requested in
https://github.com/meta-llama/llama-stack/pull/1040
- Note: the inserts uses a uuid generated from the hash of the document
id and chunk content.
2. This PR also adds unit tests for sqlite-vec. In a follow up PR, I can
add similar tests to Faiss.

## Test Plan
1. Integration tests:
```python
INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py
...
PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED
```
3. Unit tests:
```python
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py  -v -s --tb=short --disable-warnings --asyncio-mode=auto
...
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```
I also tested using the same example RAG script in
https://github.com/meta-llama/llama-stack/pull/1040 and received the
output.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Arceo 2025-02-19 20:07:46 -07:00 committed by GitHub
parent 26503ca1a4
commit 7972daa72e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 199 additions and 17 deletions

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import logging
import sqlite3
import struct
import uuid
from typing import Any, Dict, List, Optional
import numpy as np
@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
id INTEGER PRIMARY KEY,
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
USING vec0(embedding FLOAT[{self.dimension}]);
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
self.connection.commit()
@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
self.connection.commit()
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
"""
Add new chunks along with their embeddings.
Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back to maintain consistency.
@ -80,21 +82,35 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")
for chunk, emb in zip(chunks, embeddings, strict=False):
# Serialize and insert the chunk metadata.
chunk_json = chunk.model_dump_json()
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
row_id = cur.lastrowid
# Ensure the embedding is a list of floats.
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
emb_blob = serialize_vector(emb_list)
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
# Commit transaction if all inserts succeed
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
metadata_data,
)
# Prepare embeddings inserts
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
self.connection.commit()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
logger.error(f"Error inserting into {self.vector_table}: {e}")
finally:
cur.close() # Ensure cursor is closed
@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.rowid
JOIN {self.metadata_table} AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our indexs add_chunks.
# and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks)
async def query_chunks(
@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))