forked from phoenix-oss/llama-stack-mirror
feat: Chunk sqlite-vec writes (#1094)
# What does this PR do? 1. This PR adds batch inserts into sqlite-vec as requested in https://github.com/meta-llama/llama-stack/pull/1040 - Note: the inserts uses a uuid generated from the hash of the document id and chunk content. 2. This PR also adds unit tests for sqlite-vec. In a follow up PR, I can add similar tests to Faiss. ## Test Plan 1. Integration tests: ```python INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py ... PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED ``` 3. Unit tests: ```python pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ... llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED ``` I also tested using the same example RAG script in https://github.com/meta-llama/llama-stack/pull/1040 and received the output. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
26503ca1a4
commit
7972daa72e
2 changed files with 199 additions and 17 deletions
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
id INTEGER PRIMARY KEY,
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
USING vec0(embedding FLOAT[{self.dimension}]);
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
self.connection.commit()
|
||||
|
||||
|
@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
self.connection.commit()
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
|
||||
"""
|
||||
Add new chunks along with their embeddings.
|
||||
Add new chunks along with their embeddings using batch inserts.
|
||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
|
@ -80,21 +82,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for chunk, emb in zip(chunks, embeddings, strict=False):
|
||||
# Serialize and insert the chunk metadata.
|
||||
chunk_json = chunk.model_dump_json()
|
||||
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
|
||||
row_id = cur.lastrowid
|
||||
# Ensure the embedding is a list of floats.
|
||||
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
|
||||
emb_blob = serialize_vector(emb_list)
|
||||
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
|
||||
# Commit transaction if all inserts succeed
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
metadata_data,
|
||||
)
|
||||
# Prepare embeddings inserts
|
||||
embedding_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
self.connection.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self.connection.rollback() # Rollback on failure
|
||||
print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module)
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
|
||||
finally:
|
||||
cur.close() # Ensure cursor is closed
|
||||
|
@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.rowid
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
|
@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index’s add_chunks.
|
||||
# and then call our index's add_chunks.
|
||||
await self.cache[vector_db_id].insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
|
@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await self.cache[vector_db_id].query_chunks(query, params)
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
|
160
llama_stack/providers/tests/vector_io/test_sqlite_vec.py
Normal file
160
llama_stack/providers/tests/vector_io/test_sqlite_vec.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sqlite_vec
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||
SQLiteVecIndex,
|
||||
SQLiteVecVectorIOAdapter,
|
||||
generate_chunk_id,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
SQLITE_VEC_PROVIDER = "sqlite_vec"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def sqlite_connection(loop):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
try:
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def sqlite_vec_index(sqlite_connection):
|
||||
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
|
||||
n, k = 10, 3
|
||||
sample = [
|
||||
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
|
||||
for j in range(k)
|
||||
for i in range(n)
|
||||
]
|
||||
return sample
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_embeddings(sample_chunks):
|
||||
np.random.seed(42)
|
||||
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
||||
count = cur.fetchone()[0]
|
||||
assert count == len(sample_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
|
||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||
# Reduce batch size to force multiple batches for same document
|
||||
# since there are 10 chunks per document and batch size is 2
|
||||
batch_size = 2
|
||||
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
|
||||
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
|
||||
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
|
||||
# Ensure all chunk IDs are unique
|
||||
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def sqlite_vec_adapter(sqlite_connection):
|
||||
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
|
||||
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert not any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
def test_generate_chunk_id():
|
||||
chunks = [
|
||||
Chunk(content="test", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test ", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
|
||||
]
|
||||
|
||||
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
|
||||
assert chunk_ids == [
|
||||
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
|
||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||
]
|
Loading…
Add table
Add a link
Reference in a new issue