forked from phoenix-oss/llama-stack-mirror
chore: Updating sqlite-vec to make non-blocking calls (#1762)
# What does this PR do? This PR updates the sqlite-vec database calls to be non-blocking. Note that each operation creates a new connection, which incurs some performance overhead but is reasonable given [SQLite's threading and connections constraints](https://www.sqlite.org/threadsafe.html). Summary of changes: - Refactored `SQLiteVecIndex` class to store database path instead of connection object - Added `_create_sqlite_connection()` helper function to create connections on demand - Ensured proper connection closure in all database operations - Fixed test fixtures to use a file-based SQLite database for thread-safety - Updated the `SQLiteVecVectorIOAdapter` class to handle per-operation connections This PR helps chip away at https://github.com/meta-llama/llama-stack/issues/1489 ## Test Plan sqlite-vec unit tests passed locally as well as a test script using the client as a library. ## Misc FYI @varshaprasad96 @kevincogan Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
094eb6a5ae
commit
9e1ddf2b53
2 changed files with 186 additions and 125 deletions
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import sqlite3
|
||||
|
@ -29,6 +30,15 @@ def serialize_vector(vector: List[float]) -> bytes:
|
|||
return struct.pack(f"{len(vector)}f", *vector)
|
||||
|
||||
|
||||
def _create_sqlite_connection(db_path):
|
||||
"""Create a SQLite connection with sqlite_vec extension loaded."""
|
||||
connection = sqlite3.connect(db_path)
|
||||
connection.enable_load_extension(True)
|
||||
sqlite_vec.load(connection)
|
||||
connection.enable_load_extension(False)
|
||||
return connection
|
||||
|
||||
|
||||
class SQLiteVecIndex(EmbeddingIndex):
|
||||
"""
|
||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||
|
@ -37,40 +47,56 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
|
||||
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
||||
self.dimension = dimension
|
||||
self.connection = connection
|
||||
self.db_path = db_path
|
||||
self.bank_id = bank_id
|
||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
|
||||
instance = cls(dimension, connection, bank_id)
|
||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||
instance = cls(dimension, db_path, bank_id)
|
||||
await instance.initialize()
|
||||
return instance
|
||||
|
||||
async def initialize(self) -> None:
|
||||
cur = self.connection.cursor()
|
||||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
self.connection.commit()
|
||||
def _init_tables():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
async def delete(self):
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
self.connection.commit()
|
||||
await asyncio.to_thread(_init_tables)
|
||||
|
||||
async def delete(self) -> None:
|
||||
def _drop_tables():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_drop_tables)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
|
||||
"""
|
||||
|
@ -81,44 +107,55 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
"""
|
||||
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
||||
|
||||
cur = self.connection.cursor()
|
||||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
metadata_data,
|
||||
)
|
||||
# Prepare embeddings inserts
|
||||
embedding_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
self.connection.commit()
|
||||
def _execute_all_batch_inserts():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self.connection.rollback() # Rollback on failure
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
try:
|
||||
# Start transaction a single transcation for all batches
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
metadata_data,
|
||||
)
|
||||
# Prepare embeddings inserts
|
||||
embedding_data = [
|
||||
(
|
||||
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
connection.commit()
|
||||
|
||||
finally:
|
||||
cur.close() # Ensure cursor is closed
|
||||
except sqlite3.Error as e:
|
||||
connection.rollback() # Rollback on failure
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
# Process all batches in a single thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
|
@ -127,18 +164,28 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
"""
|
||||
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
cur = self.connection.cursor()
|
||||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
cur.execute(query_sql, (emb_blob, k))
|
||||
rows = cur.fetchall()
|
||||
chunks = []
|
||||
scores = []
|
||||
|
||||
def _execute_query():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
|
||||
try:
|
||||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
cur.execute(query_sql, (emb_blob, k))
|
||||
return cur.fetchall()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
rows = await asyncio.to_thread(_execute_query)
|
||||
|
||||
chunks, scores = [], []
|
||||
for _id, chunk_json, distance in rows:
|
||||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
|
@ -163,63 +210,81 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: Dict[str, VectorDBWithIndex] = {}
|
||||
self.connection: Optional[sqlite3.Connection] = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# Open a connection to the SQLite database (the file is specified in the config).
|
||||
self.connection = sqlite3.connect(self.config.db_path)
|
||||
self.connection.enable_load_extension(True)
|
||||
sqlite_vec.load(self.connection)
|
||||
self.connection.enable_load_extension(False)
|
||||
cur = self.connection.cursor()
|
||||
# Create a table to persist vector DB registrations.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||
id TEXT PRIMARY KEY,
|
||||
metadata TEXT
|
||||
);
|
||||
""")
|
||||
self.connection.commit()
|
||||
# Load any existing vector DB registrations.
|
||||
cur.execute("SELECT metadata FROM vector_dbs")
|
||||
rows = cur.fetchall()
|
||||
def _setup_connection():
|
||||
# Open a connection to the SQLite database (the file is specified in the config).
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create a table to persist vector DB registrations.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||
id TEXT PRIMARY KEY,
|
||||
metadata TEXT
|
||||
);
|
||||
""")
|
||||
connection.commit()
|
||||
# Load any existing vector DB registrations.
|
||||
cur.execute("SELECT metadata FROM vector_dbs")
|
||||
rows = cur.fetchall()
|
||||
return rows
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
rows = await asyncio.to_thread(_setup_connection)
|
||||
for row in rows:
|
||||
vector_db_data = row[0]
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
# nothing to do since we don't maintain a persistent connection
|
||||
pass
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if self.connection is None:
|
||||
raise RuntimeError("SQLite connection not initialized")
|
||||
cur = self.connection.cursor()
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||
(vector_db.identifier, vector_db.model_dump_json()),
|
||||
)
|
||||
self.connection.commit()
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||
def _register_db():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||
(vector_db.identifier, vector_db.model_dump_json()),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_register_db)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if self.connection is None:
|
||||
raise RuntimeError("SQLite connection not initialized")
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
cur = self.connection.cursor()
|
||||
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||
self.connection.commit()
|
||||
|
||||
def _delete_vector_db_from_registry():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_vector_db_from_registry)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
|
|
|
@ -5,17 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlite_vec
|
||||
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||
SQLiteVecIndex,
|
||||
SQLiteVecVectorIOAdapter,
|
||||
_create_sqlite_connection,
|
||||
generate_chunk_id,
|
||||
)
|
||||
|
||||
|
@ -36,29 +35,25 @@ def loop():
|
|||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def sqlite_connection(loop):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
try:
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
|
||||
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
|
||||
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_sqlite.db")
|
||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
|
||||
cur = connection.cursor()
|
||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
||||
count = cur.fetchone()[0]
|
||||
assert count == len(sample_chunks)
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -79,13 +74,14 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
|
|||
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
||||
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
|
||||
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
|
||||
cur = connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
# Ensure all chunk IDs are unique
|
||||
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue