feat: Chunk sqlite-vec writes (#1094)

# What does this PR do?
1. This PR adds batch inserts into sqlite-vec as requested in
   https://github.com/meta-llama/llama-stack/pull/1040
   - Note: the inserts use a UUID generated from a hash of the document ID
     and chunk content (see the short sketch below).
2. This PR also adds unit tests for sqlite-vec. In a follow-up PR, I can
   add similar tests to Faiss.
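
To make the note above concrete, the ID derivation looks roughly like this (a small sketch mirroring the `generate_chunk_id` helper added in the diff below; the document ID and content are made-up examples):

```python
import hashlib
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    # Hash "document_id:chunk_text" with md5 and render the digest as a UUID string.
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# The same (document_id, content) pair always maps to the same ID, so re-inserting
# a chunk upserts the existing row instead of creating a duplicate.
print(generate_chunk_id("doc-1", "Sentence 0 from document doc-1"))
```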

## Test Plan
1. Integration tests:
```bash
INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py
...
PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED
```
2. Unit tests:
```bash
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py  -v -s --tb=short --disable-warnings --asyncio-mode=auto
...
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```
I also tested using the same example RAG script from
https://github.com/meta-llama/llama-stack/pull/1040 and received the
expected output.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Author: Francisco Arceo
Date: 2025-02-19 20:07:46 -07:00 (committed via GitHub)
Commit: 7972daa72e (parent 26503ca1a4)
2 changed files with 199 additions and 17 deletions

Changes to `llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py`:

```diff
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import hashlib
 import logging
 import sqlite3
 import struct
+import uuid
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Create the table to store chunk metadata.
         cur.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id INTEGER PRIMARY KEY,
+                id TEXT PRIMARY KEY,
                 chunk TEXT
             );
         """)
         # Create the virtual table for embeddings.
         cur.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
-            USING vec0(embedding FLOAT[{self.dimension}]);
+            USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
         """)
         self.connection.commit()
@@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
         cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
         self.connection.commit()
 
-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
         """
-        Add new chunks along with their embeddings.
+        Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
@@ -80,21 +82,35 @@ class SQLiteVecIndex(EmbeddingIndex):
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            for chunk, emb in zip(chunks, embeddings, strict=False):
-                # Serialize and insert the chunk metadata.
-                chunk_json = chunk.model_dump_json()
-                cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
-                row_id = cur.lastrowid
-                # Ensure the embedding is a list of floats.
-                emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
-                emb_blob = serialize_vector(emb_list)
-                cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
-            # Commit transaction if all inserts succeed
+            for i in range(0, len(chunks), batch_size):
+                batch_chunks = chunks[i : i + batch_size]
+                batch_embeddings = embeddings[i : i + batch_size]
+                # Prepare metadata inserts
+                metadata_data = [
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                    for chunk in batch_chunks
+                ]
+                # Insert metadata (ON CONFLICT to avoid duplicates)
+                cur.executemany(
+                    f"""
+                    INSERT INTO {self.metadata_table} (id, chunk)
+                    VALUES (?, ?)
+                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                    """,
+                    metadata_data,
+                )
+                # Prepare embeddings inserts
+                embedding_data = [
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
+                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+                ]
+                # Insert embeddings in batch
+                cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
             self.connection.commit()
         except sqlite3.Error as e:
             self.connection.rollback()  # Rollback on failure
-            print(f"Error inserting into {self.vector_table} - error: {e}")  # Log error (Consider using logging module)
+            logger.error(f"Error inserting into {self.vector_table}: {e}")
         finally:
             cur.close()  # Ensure cursor is closed
@@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         query_sql = f"""
             SELECT m.id, m.chunk, v.distance
             FROM {self.vector_table} AS v
-            JOIN {self.metadata_table} AS m ON m.id = v.rowid
+            JOIN {self.metadata_table} AS m ON m.id = v.id
             WHERE v.embedding MATCH ? AND k = ?
             ORDER BY v.distance;
         """
@@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
         # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
-        # and then call our indexs add_chunks.
+        # and then call our index's add_chunks.
         await self.cache[vector_db_id].insert_chunks(chunks)
 
     async def query_chunks(
@@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
```
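
For anyone who wants to exercise the write path outside of llama-stack, here is a minimal, self-contained sketch of the batched upsert pattern the diff above introduces. The table names, the 4-dimensional embeddings, and the sample rows are invented for illustration; only the `executemany` batching, the `ON CONFLICT` upsert, and the ID scheme mirror the actual change (and it assumes a sqlite-vec build that accepts an extra `id TEXT` column on `vec0` tables, as the diff does):

```python
import hashlib
import sqlite3
import struct
import uuid

import sqlite_vec  # pip install sqlite-vec

DIM = 4  # illustrative; the provider uses the embedding model's real dimension


def chunk_id(document_id: str, text: str) -> str:
    # Deterministic ID: md5 of "document_id:content" rendered as a UUID string.
    return str(uuid.UUID(hashlib.md5(f"{document_id}:{text}".encode("utf-8")).hexdigest()))


def serialize(vec: list[float]) -> bytes:
    # Raw float32 bytes, the serialization sqlite-vec expects for FLOAT[] columns.
    return struct.pack(f"{len(vec)}f", *vec)


conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)

cur = conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS chunks (id TEXT PRIMARY KEY, chunk TEXT);")
cur.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vec_chunks USING vec0(embedding FLOAT[{DIM}], id TEXT);")

# Ten hypothetical chunks from one document, with toy embeddings.
docs = [("doc-1", f"Sentence {i} from document doc-1", [float(i)] * DIM) for i in range(10)]
batch_size = 4  # deliberately small so several batches are written

cur.execute("BEGIN TRANSACTION")
for start in range(0, len(docs), batch_size):
    batch = docs[start : start + batch_size]
    # Upsert metadata so re-ingesting the same chunk does not create a duplicate row.
    cur.executemany(
        "INSERT INTO chunks (id, chunk) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
        [(chunk_id(doc, text), text) for doc, text, _ in batch],
    )
    # Insert the corresponding embeddings in one batch.
    cur.executemany(
        "INSERT INTO vec_chunks (id, embedding) VALUES (?, ?);",
        [(chunk_id(doc, text), serialize(emb)) for doc, text, emb in batch],
    )
conn.commit()
print(cur.execute("SELECT COUNT(*) FROM chunks").fetchone()[0])  # -> 10
```

The deterministic IDs are what make the `ON CONFLICT` clause meaningful: re-ingesting the same document updates metadata rows in place rather than appending duplicates, and the same IDs tie each embedding row back to its chunk.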

New file: `llama_stack/providers/tests/vector_io/test_sqlite_vec.py`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import sqlite3

import numpy as np
import pytest
import sqlite_vec

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
    SQLiteVecIndex,
    SQLiteVecVectorIOAdapter,
    generate_chunk_id,
)

# How to run this test:
#
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
#     -v -s --tb=short --disable-warnings --asyncio-mode=auto

SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"


@pytest.fixture(scope="session")
def loop():
    return asyncio.new_event_loop()


@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
    conn = sqlite3.connect(":memory:")
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        yield conn
    finally:
        conn.close()


@pytest.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
    return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")


@pytest.fixture(scope="session")
def sample_chunks():
    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
    n, k = 10, 3
    sample = [
        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
        for j in range(k)
        for i in range(n)
    ]
    return sample


@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
    np.random.seed(42)
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])


@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
    cur = sqlite_vec_index.connection.cursor()
    cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
    count = cur.fetchone()[0]
    assert count == len(sample_chunks)


@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
    query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2


@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
    """Test that chunk IDs do not conflict across batches when inserting chunks."""
    # Reduce batch size to force multiple batches for same document
    # since there are 10 chunks per document and batch size is 2
    batch_size = 2
    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)

    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)

    cur = sqlite_vec_index.connection.cursor()

    # Retrieve all chunk IDs to check for duplicates
    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
    chunk_ids = [row[0] for row in cur.fetchall()]
    cur.close()

    # Ensure all chunk IDs are unique
    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"


@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
    config = type("Config", (object,), {"db_path": ":memory:"})  # Mock config with in-memory database
    adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
    await adapter.initialize()
    yield adapter
    await adapter.shutdown()


@pytest.mark.asyncio
async def test_register_vector_db(sqlite_vec_adapter):
    vector_db = VectorDB(
        identifier="test_db",
        embedding_model=EMBEDDING_MODEL,
        embedding_dimension=EMBEDDING_DIMENSION,
        metadata={},
        provider_id=SQLITE_VEC_PROVIDER,
    )
    await sqlite_vec_adapter.register_vector_db(vector_db)
    vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
    assert any(db.identifier == "test_db" for db in vector_dbs)


@pytest.mark.asyncio
async def test_unregister_vector_db(sqlite_vec_adapter):
    vector_db = VectorDB(
        identifier="test_db",
        embedding_model=EMBEDDING_MODEL,
        embedding_dimension=EMBEDDING_DIMENSION,
        metadata={},
        provider_id=SQLITE_VEC_PROVIDER,
    )
    await sqlite_vec_adapter.register_vector_db(vector_db)
    await sqlite_vec_adapter.unregister_vector_db("test_db")
    vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
    assert not any(db.identifier == "test_db" for db in vector_dbs)


def test_generate_chunk_id():
    chunks = [
        Chunk(content="test", metadata={"document_id": "doc-1"}),
        Chunk(content="test ", metadata={"document_id": "doc-1"}),
        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
    ]
    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
    assert chunk_ids == [
        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
        "f68df25d-d9aa-ab4d-5684-64a233add20d",
    ]
```