chore: Updating sqlite-vec to make non-blocking calls (#1762)

# What does this PR do?
This PR updates the sqlite-vec database calls to be non-blocking: each
database operation now runs in a worker thread via `asyncio.to_thread`
instead of blocking the event loop. Because a SQLite connection should
not be shared across threads, each operation opens its own connection.
This incurs some performance overhead but is reasonable given [SQLite's
threading and connection
constraints](https://www.sqlite.org/threadsafe.html). A sketch of the
pattern follows the summary below.

Summary of changes:
- Refactored `SQLiteVecIndex` class to store database path instead of
connection object
- Added `_create_sqlite_connection()` helper function to create
connections on demand
- Ensured proper connection closure in all database operations
- Fixed test fixtures to use a file-based SQLite database for thread
safety (see the demonstration under Test Plan)
- Updated the `SQLiteVecVectorIOAdapter` class to handle per-operation
connections

This PR helps chip away at
https://github.com/meta-llama/llama-stack/issues/1489
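
The per-operation pattern looks like this. As a minimal sketch:
`_create_sqlite_connection` is the new helper from this PR, while
`count_rows` and its `table` parameter are hypothetical, added here only
to illustrate the open/work/close-in-a-worker-thread shape:

```python
import asyncio
import sqlite3

import sqlite_vec


def _create_sqlite_connection(db_path):
    """Create a SQLite connection with sqlite_vec extension loaded."""
    connection = sqlite3.connect(db_path)
    connection.enable_load_extension(True)
    sqlite_vec.load(connection)
    connection.enable_load_extension(False)
    return connection


async def count_rows(db_path: str, table: str) -> int:
    # Hypothetical operation, for illustration only.
    # The blocking work lives in a plain function...
    def _count() -> int:
        connection = _create_sqlite_connection(db_path)
        try:
            return connection.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
        finally:
            # Each operation opens and closes its own connection.
            connection.close()

    # ...and runs on a worker thread so the event loop stays free.
    return await asyncio.to_thread(_count)
```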

## Test Plan
The sqlite-vec unit tests passed locally, as did a test script
exercising the client as a library.
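
The switch to a file-based database in the test fixtures matters
because an in-memory SQLite database is private to the connection that
created it; with per-operation connections, a `:memory:` database would
look empty to every new connection. A quick standalone demonstration of
the difference (not part of the test suite):

```python
import os
import sqlite3
import tempfile

# An in-memory database is visible only to its creating connection:
# a second ":memory:" connection gets a brand-new, empty database.
conn_a = sqlite3.connect(":memory:")
conn_a.execute("CREATE TABLE t (x INTEGER)")
conn_b = sqlite3.connect(":memory:")
assert conn_b.execute(
    "SELECT COUNT(*) FROM sqlite_master WHERE name = 't'"
).fetchone()[0] == 0  # conn_b cannot see conn_a's table

# A file-based database is shared by every connection that opens the
# path, so per-operation connections all see the same data.
db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
with sqlite3.connect(db_path) as conn:
    conn.execute("CREATE TABLE t (x INTEGER)")
assert sqlite3.connect(db_path).execute(
    "SELECT COUNT(*) FROM sqlite_master WHERE name = 't'"
).fetchone()[0] == 1  # a fresh connection sees the table
```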

## Misc

FYI @varshaprasad96 @kevincogan

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Arceo · 2025-03-23 18:25:44 -06:00 · committed via GitHub
parent 094eb6a5ae
commit 9e1ddf2b53
2 changed files with 186 additions and 125 deletions

llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import hashlib
import logging
import sqlite3
@@ -29,6 +30,15 @@ def serialize_vector(vector: List[float]) -> bytes:
return struct.pack(f"{len(vector)}f", *vector)
def _create_sqlite_connection(db_path):
"""Create a SQLite connection with sqlite_vec extension loaded."""
connection = sqlite3.connect(db_path)
connection.enable_load_extension(True)
sqlite_vec.load(connection)
connection.enable_load_extension(False)
return connection
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@@ -37,21 +47,24 @@ class SQLiteVecIndex(EmbeddingIndex):
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
"""
def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
def __init__(self, dimension: int, db_path: str, bank_id: str):
self.dimension = dimension
self.connection = connection
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
@classmethod
async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
instance = cls(dimension, connection, bank_id)
async def create(cls, dimension: int, db_path: str, bank_id: str):
instance = cls(dimension, db_path, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None:
cur = self.connection.cursor()
def _init_tables():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
@@ -64,13 +77,26 @@ class SQLiteVecIndex(EmbeddingIndex):
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
self.connection.commit()
connection.commit()
finally:
cur.close()
connection.close()
async def delete(self):
cur = self.connection.cursor()
await asyncio.to_thread(_init_tables)
async def delete(self) -> None:
def _drop_tables():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
self.connection.commit()
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_drop_tables)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
"""
@@ -81,9 +107,12 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
cur = self.connection.cursor()
def _execute_all_batch_inserts():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
# Start transaction
# Start a single transaction for all batches
cur.execute("BEGIN TRANSACTION")
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size]
@@ -105,20 +134,28 @@
)
# Prepare embeddings inserts
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
(
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
serialize_vector(emb.tolist()),
)
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
if isinstance(chunk.content, str)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
self.connection.commit()
connection.commit()
except sqlite3.Error as e:
self.connection.rollback() # Rollback on failure
connection.rollback() # Rollback on failure
logger.error(f"Error inserting into {self.vector_table}: {e}")
raise
finally:
cur.close() # Ensure cursor is closed
cur.close()
connection.close()
# Process all batches in a single thread
await asyncio.to_thread(_execute_all_batch_inserts)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
@@ -127,7 +164,12 @@
"""
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
emb_blob = serialize_vector(emb_list)
cur = self.connection.cursor()
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
@@ -136,9 +178,14 @@
ORDER BY v.distance;
"""
cur.execute(query_sql, (emb_blob, k))
rows = cur.fetchall()
chunks = []
scores = []
return cur.fetchall()
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_execute_query)
chunks, scores = [], []
for _id, chunk_json, distance in rows:
try:
chunk = Chunk.model_validate_json(chunk_json)
@@ -163,15 +210,13 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.config = config
self.inference_api = inference_api
self.cache: Dict[str, VectorDBWithIndex] = {}
self.connection: Optional[sqlite3.Connection] = None
async def initialize(self) -> None:
def _setup_connection():
# Open a connection to the SQLite database (the file is specified in the config).
self.connection = sqlite3.connect(self.config.db_path)
self.connection.enable_load_extension(True)
sqlite_vec.load(self.connection)
self.connection.enable_load_extension(False)
cur = self.connection.cursor()
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
@@ -179,47 +224,67 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
metadata TEXT
);
""")
self.connection.commit()
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_setup_connection)
for row in rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def shutdown(self) -> None:
if self.connection:
self.connection.close()
self.connection = None
# nothing to do since we don't maintain a persistent connection
pass
async def register_vector_db(self, vector_db: VectorDB) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
cur = self.connection.cursor()
def _register_db():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
self.connection.commit()
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> List[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
cur = self.connection.cursor()
def _delete_vector_db_from_registry():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
self.connection.commit()
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_vector_db_from_registry)
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
if vector_db_id not in self.cache:


@@ -5,17 +5,16 @@
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
generate_chunk_id,
)
@@ -36,29 +35,25 @@ def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
@@ -79,13 +74,14 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"