From 9e1ddf2b538d24d6675a3f9e5310fb5de665906d Mon Sep 17 00:00:00 2001
From: Francisco Arceo
Date: Sun, 23 Mar 2025 18:25:44 -0600
Subject: [PATCH] chore: Updating sqlite-vec to make non-blocking calls (#1762)

# What does this PR do?
This PR updates the sqlite-vec database calls to be non-blocking. Note that each operation creates a new connection, which incurs some performance overhead but is reasonable given [SQLite's threading and connection constraints](https://www.sqlite.org/threadsafe.html). A minimal sketch of this per-operation pattern appears after the diff.

Summary of changes:
- Refactored the `SQLiteVecIndex` class to store the database path instead of a connection object
- Added a `_create_sqlite_connection()` helper function to create connections on demand
- Ensured proper connection closure in all database operations
- Fixed test fixtures to use a file-based SQLite database for thread-safety
- Updated the `SQLiteVecVectorIOAdapter` class to handle per-operation connections

This PR helps chip away at https://github.com/meta-llama/llama-stack/issues/1489

## Test Plan
The sqlite-vec unit tests passed locally, as did a test script that uses the client as a library.

## Misc
FYI @varshaprasad96 @kevincogan

Signed-off-by: Francisco Javier Arceo
---
 .../inline/vector_io/sqlite_vec/sqlite_vec.py | 279 +++++++++++-------
 .../providers/vector_io/test_sqlite_vec.py    |  32 +-
 2 files changed, 186 insertions(+), 125 deletions(-)

diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index b8f6f602f..5f7671138 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import hashlib
 import logging
 import sqlite3
@@ -29,6 +30,15 @@ def serialize_vector(vector: List[float]) -> bytes:
     return struct.pack(f"{len(vector)}f", *vector)
 
 
+def _create_sqlite_connection(db_path):
+    """Create a SQLite connection with sqlite_vec extension loaded."""
+    connection = sqlite3.connect(db_path)
+    connection.enable_load_extension(True)
+    sqlite_vec.load(connection)
+    connection.enable_load_extension(False)
+    return connection
+
+
 class SQLiteVecIndex(EmbeddingIndex):
     """
     An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@@ -37,40 +47,56 @@ class SQLiteVecIndex(EmbeddingIndex):
     - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
     """
 
-    def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
+    def __init__(self, dimension: int, db_path: str, bank_id: str):
         self.dimension = dimension
-        self.connection = connection
+        self.db_path = db_path
         self.bank_id = bank_id
         self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
         self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
 
     @classmethod
-    async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
-        instance = cls(dimension, connection, bank_id)
+    async def create(cls, dimension: int, db_path: str, bank_id: str):
+        instance = cls(dimension, db_path, bank_id)
         await instance.initialize()
         return instance
 
     async def initialize(self) -> None:
-        cur = self.connection.cursor()
-        # Create the table to store chunk metadata.
-        cur.execute(f"""
-            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id TEXT PRIMARY KEY,
-                chunk TEXT
-            );
-        """)
-        # Create the virtual table for embeddings.
-        cur.execute(f"""
-            CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
-            USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
-        """)
-        self.connection.commit()
+        def _init_tables():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+            try:
+                # Create the table to store chunk metadata.
+                cur.execute(f"""
+                    CREATE TABLE IF NOT EXISTS {self.metadata_table} (
+                        id TEXT PRIMARY KEY,
+                        chunk TEXT
+                    );
+                """)
+                # Create the virtual table for embeddings.
+                cur.execute(f"""
+                    CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
+                    USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
+                """)
+                connection.commit()
+            finally:
+                cur.close()
+                connection.close()
 
-    async def delete(self):
-        cur = self.connection.cursor()
-        cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
-        cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
-        self.connection.commit()
+        await asyncio.to_thread(_init_tables)
+
+    async def delete(self) -> None:
+        def _drop_tables():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
+                cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
+                connection.commit()
+            finally:
+                cur.close()
+                connection.close()
+
+        await asyncio.to_thread(_drop_tables)
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
         """
@@ -81,44 +107,55 @@ class SQLiteVecIndex(EmbeddingIndex):
         """
         assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
-        cur = self.connection.cursor()
-        try:
-            # Start transaction
-            cur.execute("BEGIN TRANSACTION")
-            for i in range(0, len(chunks), batch_size):
-                batch_chunks = chunks[i : i + batch_size]
-                batch_embeddings = embeddings[i : i + batch_size]
-                # Prepare metadata inserts
-                metadata_data = [
-                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
-                    for chunk in batch_chunks
-                    if isinstance(chunk.content, str)
-                ]
-                # Insert metadata (ON CONFLICT to avoid duplicates)
-                cur.executemany(
-                    f"""
-                    INSERT INTO {self.metadata_table} (id, chunk)
-                    VALUES (?, ?)
-                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
-                    """,
-                    metadata_data,
-                )
-                # Prepare embeddings inserts
-                embedding_data = [
-                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
-                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
-                    if isinstance(chunk.content, str)
-                ]
-                # Insert embeddings in batch
-                cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
-            self.connection.commit()
+        def _execute_all_batch_inserts():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
 
-        except sqlite3.Error as e:
-            self.connection.rollback()  # Rollback on failure
-            logger.error(f"Error inserting into {self.vector_table}: {e}")
+            try:
+                # Start a single transaction for all batches
+                cur.execute("BEGIN TRANSACTION")
+                for i in range(0, len(chunks), batch_size):
+                    batch_chunks = chunks[i : i + batch_size]
+                    batch_embeddings = embeddings[i : i + batch_size]
+                    # Prepare metadata inserts
+                    metadata_data = [
+                        (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                        for chunk in batch_chunks
+                        if isinstance(chunk.content, str)
+                    ]
+                    # Insert metadata (ON CONFLICT to avoid duplicates)
+                    cur.executemany(
+                        f"""
+                        INSERT INTO {self.metadata_table} (id, chunk)
+                        VALUES (?, ?)
+                        ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                        """,
+                        metadata_data,
+                    )
+                    # Prepare embeddings inserts
+                    embedding_data = [
+                        (
+                            generate_chunk_id(chunk.metadata["document_id"], chunk.content),
+                            serialize_vector(emb.tolist()),
+                        )
+                        for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+                        if isinstance(chunk.content, str)
+                    ]
+                    # Insert embeddings in batch
+                    cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
+                connection.commit()
 
-        finally:
-            cur.close()  # Ensure cursor is closed
+            except sqlite3.Error as e:
+                connection.rollback()  # Rollback on failure
+                logger.error(f"Error inserting into {self.vector_table}: {e}")
+                raise
+
+            finally:
+                cur.close()
+                connection.close()
+
+        # Process all batches in a single thread
+        await asyncio.to_thread(_execute_all_batch_inserts)
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         """
@@ -127,18 +164,28 @@ class SQLiteVecIndex(EmbeddingIndex):
         """
         emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
         emb_blob = serialize_vector(emb_list)
-        cur = self.connection.cursor()
-        query_sql = f"""
-            SELECT m.id, m.chunk, v.distance
-            FROM {self.vector_table} AS v
-            JOIN {self.metadata_table} AS m ON m.id = v.id
-            WHERE v.embedding MATCH ? AND k = ?
-            ORDER BY v.distance;
-        """
-        cur.execute(query_sql, (emb_blob, k))
-        rows = cur.fetchall()
-        chunks = []
-        scores = []
+
+        def _execute_query():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+
+            try:
+                query_sql = f"""
+                    SELECT m.id, m.chunk, v.distance
+                    FROM {self.vector_table} AS v
+                    JOIN {self.metadata_table} AS m ON m.id = v.id
+                    WHERE v.embedding MATCH ? AND k = ?
+                    ORDER BY v.distance;
+                """
+                cur.execute(query_sql, (emb_blob, k))
+                return cur.fetchall()
+            finally:
+                cur.close()
+                connection.close()
+
+        rows = await asyncio.to_thread(_execute_query)
+
+        chunks, scores = [], []
         for _id, chunk_json, distance in rows:
             try:
                 chunk = Chunk.model_validate_json(chunk_json)
@@ -163,63 +210,81 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self.config = config
         self.inference_api = inference_api
         self.cache: Dict[str, VectorDBWithIndex] = {}
-        self.connection: Optional[sqlite3.Connection] = None
 
     async def initialize(self) -> None:
-        # Open a connection to the SQLite database (the file is specified in the config).
-        self.connection = sqlite3.connect(self.config.db_path)
-        self.connection.enable_load_extension(True)
-        sqlite_vec.load(self.connection)
-        self.connection.enable_load_extension(False)
-        cur = self.connection.cursor()
-        # Create a table to persist vector DB registrations.
-        cur.execute("""
-            CREATE TABLE IF NOT EXISTS vector_dbs (
-                id TEXT PRIMARY KEY,
-                metadata TEXT
-            );
-        """)
-        self.connection.commit()
-        # Load any existing vector DB registrations.
-        cur.execute("SELECT metadata FROM vector_dbs")
-        rows = cur.fetchall()
+        def _setup_connection():
+            # Open a connection to the SQLite database (the file is specified in the config).
+            connection = _create_sqlite_connection(self.config.db_path)
+            cur = connection.cursor()
+            try:
+                # Create a table to persist vector DB registrations.
+                cur.execute("""
+                    CREATE TABLE IF NOT EXISTS vector_dbs (
+                        id TEXT PRIMARY KEY,
+                        metadata TEXT
+                    );
+                """)
+                connection.commit()
+                # Load any existing vector DB registrations.
+ cur.execute("SELECT metadata FROM vector_dbs") + rows = cur.fetchall() + return rows + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_setup_connection) for row in rows: vector_db_data = row[0] vector_db = VectorDB.model_validate_json(vector_db_data) - index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, self.config.db_path, vector_db.identifier + ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def shutdown(self) -> None: - if self.connection: - self.connection.close() - self.connection = None + # nothing to do since we don't maintain a persistent connection + pass async def register_vector_db(self, vector_db: VectorDB) -> None: - if self.connection is None: - raise RuntimeError("SQLite connection not initialized") - cur = self.connection.cursor() - cur.execute( - "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", - (vector_db.identifier, vector_db.model_dump_json()), - ) - self.connection.commit() - index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + def _register_db(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute( + "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", + (vector_db.identifier, vector_db.model_dump_json()), + ) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_register_db) + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def list_vector_dbs(self) -> List[VectorDB]: return [v.vector_db for v in self.cache.values()] async def unregister_vector_db(self, vector_db_id: str) -> None: - if self.connection is None: - raise RuntimeError("SQLite connection not initialized") if vector_db_id not in self.cache: logger.warning(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] - cur = self.connection.cursor() - cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) - self.connection.commit() + + def _delete_vector_db_from_registry(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_delete_vector_db_from_registry) async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None: if vector_db_id not in self.cache: diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index cff988c53..32b60ffa5 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -5,17 +5,16 @@ # the root directory of this source tree. 
 
 import asyncio
-import sqlite3
 
 import numpy as np
 import pytest
 import pytest_asyncio
-import sqlite_vec
 
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
     SQLiteVecIndex,
     SQLiteVecVectorIOAdapter,
+    _create_sqlite_connection,
     generate_chunk_id,
 )
 
@@ -36,29 +35,25 @@ def loop():
     return asyncio.new_event_loop()
 
 
-@pytest.fixture(scope="session", autouse=True)
-def sqlite_connection(loop):
-    conn = sqlite3.connect(":memory:")
-    try:
-        conn.enable_load_extension(True)
-        sqlite_vec.load(conn)
-        yield conn
-    finally:
-        conn.close()
-
-
 @pytest_asyncio.fixture(scope="session", autouse=True)
-async def sqlite_vec_index(sqlite_connection, embedding_dimension):
-    return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
+async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / "test_sqlite.db")
+    index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
+    yield index
+    await index.delete()
 
 
 @pytest.mark.asyncio
 async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
-    cur = sqlite_vec_index.connection.cursor()
+    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
+    cur = connection.cursor()
     cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
     count = cur.fetchone()[0]
     assert count == len(sample_chunks)
+    cur.close()
+    connection.close()
 
 
 @pytest.mark.asyncio
@@ -79,13 +74,14 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
     sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
 
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
-
-    cur = sqlite_vec_index.connection.cursor()
+    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
+    cur = connection.cursor()
 
     # Retrieve all chunk IDs to check for duplicates
     cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
     chunk_ids = [row[0] for row in cur.fetchall()]
     cur.close()
+    connection.close()
 
     # Ensure all chunk IDs are unique
     assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
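
For reviewers, here is a minimal, self-contained sketch of the per-operation connection pattern the patch adopts. It is illustrative only: it uses plain `sqlite3` rather than the real `_create_sqlite_connection()` helper (which also loads the sqlite-vec extension), and the `kv` table, the `insert_row()` function, and the database path are made up for this example.

```python
import asyncio
import sqlite3


async def insert_row(db_path: str, key: str, value: str) -> None:
    """Run a blocking SQLite write without blocking the event loop."""

    def _insert() -> None:
        # A fresh, short-lived connection per operation: no sqlite3 object
        # crosses a thread boundary, which is what SQLite's default
        # thread-safety mode expects.
        connection = sqlite3.connect(db_path)
        try:
            connection.execute("CREATE TABLE IF NOT EXISTS kv (id TEXT PRIMARY KEY, v TEXT)")
            connection.execute("INSERT OR REPLACE INTO kv (id, v) VALUES (?, ?)", (key, value))
            connection.commit()
        finally:
            connection.close()

    # The blocking work runs on a worker thread; awaiting it keeps the loop responsive.
    await asyncio.to_thread(_insert)


if __name__ == "__main__":
    asyncio.run(insert_row("/tmp/demo.db", "k1", "hello"))
```

The trade-off is the one noted in the description: opening a connection per call adds some overhead, but it avoids sharing a connection across threads entirely.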