feat: Chunk sqlite-vec writes (#1094)
# What does this PR do?

1. This PR adds batch inserts into sqlite-vec, as requested in https://github.com/meta-llama/llama-stack/pull/1040.
   - Note: each insert uses a UUID generated from a hash of the document ID and the chunk content.
2. This PR also adds unit tests for sqlite-vec. In a follow-up PR, I can add similar tests for Faiss.

## Test Plan

1. Integration tests:

```
INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py
...
PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED
```

2. Unit tests:

```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
...
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```

I also tested with the same example RAG script from https://github.com/meta-llama/llama-stack/pull/1040 and received the expected output.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 26503ca1a4
commit 7972daa72e
2 changed files with 199 additions and 17 deletions
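For illustration, the deterministic chunk-ID scheme described above can be sketched in isolation. This is a standalone sketch (the `chunk_id_sketch` name is mine, not from the diff); the actual helper, `generate_chunk_id`, appears in the last hunk below.

```python
import hashlib
import uuid


def chunk_id_sketch(document_id: str, chunk_text: str) -> str:
    """Derive a stable ID: md5 of "document_id:chunk_text", read back as a UUID."""
    digest = hashlib.md5(f"{document_id}:{chunk_text}".encode("utf-8")).hexdigest()
    # An md5 hex digest is 32 characters, exactly the width uuid.UUID accepts.
    return str(uuid.UUID(digest))


# Same inputs always yield the same ID, so re-ingesting a document can
# upsert instead of piling up duplicate rows.
assert chunk_id_sketch("doc-1", "hello world") == chunk_id_sketch("doc-1", "hello world")
```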
```diff
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import hashlib
 import logging
 import sqlite3
 import struct
+import uuid
 from typing import Any, Dict, List, Optional
 
 import numpy as np
```
```diff
@@ -52,14 +54,14 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Create the table to store chunk metadata.
         cur.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.metadata_table} (
-                id INTEGER PRIMARY KEY,
+                id TEXT PRIMARY KEY,
                 chunk TEXT
             );
         """)
         # Create the virtual table for embeddings.
         cur.execute(f"""
             CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
-            USING vec0(embedding FLOAT[{self.dimension}]);
+            USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
         """)
         self.connection.commit()
 
```
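The schema change above (a TEXT primary key plus an `id TEXT` column on the vec0 table) can be exercised standalone. A minimal sketch, assuming the sqlite-vec Python bindings are installed as `sqlite_vec` and that `serialize_vector` packs floats into a raw float32 blob via `struct` (the `import struct` in this file suggests as much); the table names are illustrative:

```python
import sqlite3
import struct

import sqlite_vec  # assumed: the sqlite-vec Python bindings


def serialize_vector_sketch(vector: list[float]) -> bytes:
    # Pack the embedding as raw float32 bytes, the format vec0 accepts.
    return struct.pack(f"{len(vector)}f", *vector)


conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)  # register the vec0 module on this connection
conn.enable_load_extension(False)

dimension = 4
cur = conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS chunks (id TEXT PRIMARY KEY, chunk TEXT);")
# The extra `id TEXT` column lets queries join back to the metadata table
# directly, instead of relying on matching rowids.
cur.execute(f"CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(embedding FLOAT[{dimension}], id TEXT);")
conn.commit()
```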
```diff
@@ -69,9 +71,9 @@ class SQLiteVecIndex(EmbeddingIndex):
         cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
         self.connection.commit()
 
-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
         """
-        Add new chunks along with their embeddings.
+        Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
```
```diff
@@ -80,21 +82,35 @@
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            for chunk, emb in zip(chunks, embeddings, strict=False):
-                # Serialize and insert the chunk metadata.
-                chunk_json = chunk.model_dump_json()
-                cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
-                row_id = cur.lastrowid
-                # Ensure the embedding is a list of floats.
-                emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
-                emb_blob = serialize_vector(emb_list)
-                cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
-            # Commit transaction if all inserts succeed
+            for i in range(0, len(chunks), batch_size):
+                batch_chunks = chunks[i : i + batch_size]
+                batch_embeddings = embeddings[i : i + batch_size]
+                # Prepare metadata inserts
+                metadata_data = [
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                    for chunk in batch_chunks
+                ]
+                # Insert metadata (ON CONFLICT to avoid duplicates)
+                cur.executemany(
+                    f"""
+                    INSERT INTO {self.metadata_table} (id, chunk)
+                    VALUES (?, ?)
+                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
+                    """,
+                    metadata_data,
+                )
+                # Prepare embeddings inserts
+                embedding_data = [
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
+                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+                ]
+                # Insert embeddings in batch
+                cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
             self.connection.commit()
 
         except sqlite3.Error as e:
             self.connection.rollback()  # Rollback on failure
-            print(f"Error inserting into {self.vector_table} - error: {e}")  # Log error (Consider using logging module)
+            logger.error(f"Error inserting into {self.vector_table}: {e}")
 
         finally:
             cur.close()  # Ensure cursor is closed
```
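The batching pattern above (slice the input, build parameter tuples, then one `executemany` per statement per batch) is independent of vec0 and can be demonstrated with plain sqlite3. A minimal sketch under that assumption; the `metadata` table name and the `insert_batched` helper are illustrative, not from the diff:

```python
import sqlite3


def insert_batched(conn: sqlite3.Connection, rows: list[tuple[str, str]], batch_size: int = 500) -> None:
    """Upsert (id, chunk) pairs in batches inside a single transaction."""
    cur = conn.cursor()
    try:
        cur.execute("BEGIN TRANSACTION")
        for i in range(0, len(rows), batch_size):
            batch = rows[i : i + batch_size]
            # One round trip per batch through a single prepared statement;
            # ON CONFLICT turns a re-insert of an existing id into an update.
            cur.executemany(
                """
                INSERT INTO metadata (id, chunk)
                VALUES (?, ?)
                ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
                """,
                batch,
            )
        conn.commit()
    except sqlite3.Error:
        conn.rollback()  # keep the table consistent if any batch fails
        raise
    finally:
        cur.close()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (id TEXT PRIMARY KEY, chunk TEXT);")
insert_batched(conn, [(f"id-{n}", f"chunk {n}") for n in range(1200)])
print(conn.execute("SELECT COUNT(*) FROM metadata").fetchone())  # (1200,)
```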
```diff
@@ -110,7 +126,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         query_sql = f"""
             SELECT m.id, m.chunk, v.distance
             FROM {self.vector_table} AS v
-            JOIN {self.metadata_table} AS m ON m.id = v.rowid
+            JOIN {self.metadata_table} AS m ON m.id = v.id
             WHERE v.embedding MATCH ? AND k = ?
             ORDER BY v.distance;
         """
```
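In vec0's KNN syntax, `embedding MATCH ?` combined with `k = ?` requests the k nearest neighbors of the bound vector, and `distance` comes back ready for ordering; the join now uses the explicit `id` column instead of `rowid`. A hedged sketch of running the corrected query, continuing from the in-memory setup sketched earlier (so `conn` and `serialize_vector_sketch` are assumed from that block):

```python
# Continues the earlier vec0 setup sketch: assumes `conn` has the vec0
# extension loaded and the illustrative `chunks` / `vectors` tables created.
query_vector = [0.1, 0.2, 0.3, 0.4]  # illustrative query embedding
k = 5

rows = conn.execute(
    """
    SELECT m.id, m.chunk, v.distance
    FROM vectors AS v
    JOIN chunks AS m ON m.id = v.id
    WHERE v.embedding MATCH ? AND k = ?
    ORDER BY v.distance;
    """,
    (serialize_vector_sketch(query_vector), k),
).fetchall()

for chunk_id, chunk_json, distance in rows:
    print(chunk_id, distance)
```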
```diff
@@ -204,7 +220,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
         # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
-        # and then call our index’s add_chunks.
+        # and then call our index's add_chunks.
         await self.cache[vector_db_id].insert_chunks(chunks)
 
     async def query_chunks(
```
```diff
@@ -213,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
```
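Because `generate_chunk_id` is deterministic, re-ingesting the same chunk hits the `ON CONFLICT` path in `add_chunks` and overwrites rather than duplicates. A small sketch of that property using plain sqlite3 for the metadata side (no vec0 required; the `metadata` table name is illustrative):

```python
import hashlib
import sqlite3
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Mirror of the helper above: md5 of "document_id:chunk_text" as a UUID string."""
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (id TEXT PRIMARY KEY, chunk TEXT);")

for _ in range(2):  # ingest the identical chunk twice
    conn.execute(
        "INSERT INTO metadata (id, chunk) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;",
        (generate_chunk_id("doc-1", "hello world"), '{"content": "hello world"}'),
    )

print(conn.execute("SELECT COUNT(*) FROM metadata").fetchone())  # (1,): one row, not two
```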