updating chunk IDs to use a hash of document ID and text content to guarantee uniqueness across insert batches

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-02-18 19:53:31 -05:00
parent 5a6c95ecf9
commit 898d325772
2 changed files with 68 additions and 27 deletions
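
The bug this fixes: chunk IDs were built as f"{document_id}:chunk-{j}" with j enumerating chunks within each batch, so the counter reset to 0 on every batch and chunks of the same document collided across batches. A minimal standalone sketch of the failure and of the new scheme (generate_chunk_id is copied from the patch below; the sample data is illustrative):

    import hashlib
    import uuid

    def generate_chunk_id(document_id: str, chunk_text: str) -> str:
        """Deterministic ID: MD5 of "document_id:chunk_text", rendered as a UUID string."""
        hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
        return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

    chunks = [f"sentence {i}" for i in range(4)]
    batch_size = 2

    # Old scheme: j restarts at 0 in every batch, so IDs repeat across batches.
    old_ids = [
        f"doc-1:chunk-{j}"
        for start in range(0, len(chunks), batch_size)
        for j, _ in enumerate(chunks[start : start + batch_size])
    ]
    assert old_ids == ["doc-1:chunk-0", "doc-1:chunk-1", "doc-1:chunk-0", "doc-1:chunk-1"]

    # New scheme: the ID depends only on document ID and chunk text, so distinct
    # chunks get distinct IDs regardless of how the insert is batched.
    new_ids = [generate_chunk_id("doc-1", text) for text in chunks]
    assert len(set(new_ids)) == len(new_ids)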

llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import hashlib
 import logging
 import sqlite3
 import struct
+import uuid
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -80,13 +82,13 @@ class SQLiteVecIndex(EmbeddingIndex):
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            for i in range(0, len(chunks), batch_size):
+            for k, i in enumerate(range(0, len(chunks), batch_size)):
                 batch_chunks = chunks[i : i + batch_size]
                 batch_embeddings = embeddings[i : i + batch_size]
                 # Prepare metadata inserts
                 metadata_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
-                    for j, chunk in enumerate(batch_chunks)
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                    for chunk in batch_chunks
                 ]
                 # Insert metadata (ON CONFLICT to avoid duplicates)
                 cur.executemany(
@@ -99,8 +101,8 @@ class SQLiteVecIndex(EmbeddingIndex):
                 )
                 # Prepare embeddings inserts
                 embedding_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
-                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
+                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                 ]
                 # Insert embeddings in batch
                 cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
@@ -227,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
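
Worth noting (an observation about the patch, not part of it): generate_chunk_id is a pure function of its inputs, so re-ingesting an identical chunk produces the same key, and the metadata insert's ON CONFLICT clause then deduplicates it; the embeddings insert above has no such clause, so callers still should not double-insert. A minimal check:

    a = generate_chunk_id("doc-1", "Python is a high-level programming language.")
    b = generate_chunk_id("doc-1", "Python is a high-level programming language.")
    assert a == b  # deterministic: same document ID + same text -> same chunk ID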

test_sqlite_vec.py

@@ -13,7 +13,11 @@ import sqlite_vec
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
+    SQLiteVecIndex,
+    SQLiteVecVectorIOAdapter,
+    generate_chunk_id,
+)
 
 # How to run this test:
 #
@@ -46,34 +50,27 @@ async def sqlite_vec_index(sqlite_connection):
     return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def sample_chunks():
-    return [
-        Chunk(
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "document_id": "doc 1"},
-        ),
-        Chunk(
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "document_id": "doc 1"},
-        ),
+    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
+    n, k = 10, 3
+    sample = [
+        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
+        for j in range(k)
+        for i in range(n)
     ]
+    return sample
 
 
-@pytest.fixture
-def sample_embeddings():
+@pytest.fixture(scope="session")
+def sample_embeddings(sample_chunks):
     np.random.seed(42)
-    return np.array(
-        [
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-        ]
-    )
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
 
 
 @pytest.mark.asyncio
 async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
     cur = sqlite_vec_index.connection.cursor()
     cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
     count = cur.fetchone()[0]
@@ -84,9 +81,30 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
 async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
     query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
+    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
     assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) > 0
+    assert len(response.chunks) == 2
+
+
+@pytest.mark.asyncio
+async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
+    """Test that chunk IDs do not conflict across batches when inserting chunks."""
+    # Reduce batch size to force multiple batches for same document
+    # since there are 10 chunks per document and batch size is 2
+    batch_size = 2
+    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
+
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
+
+    cur = sqlite_vec_index.connection.cursor()
+
+    # Retrieve all chunk IDs to check for duplicates
+    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
+    chunk_ids = [row[0] for row in cur.fetchall()]
+    cur.close()
+
+    # Ensure all chunk IDs are unique
+    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
 
 
 @pytest.fixture(scope="session")
@@ -125,3 +143,18 @@ async def test_unregister_vector_db(sqlite_vec_adapter):
     await sqlite_vec_adapter.unregister_vector_db("test_db")
     vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
     assert not any(db.identifier == "test_db" for db in vector_dbs)
+
+
+def test_generate_chunk_id():
+    chunks = [
+        Chunk(content="test", metadata={"document_id": "doc-1"}),
+        Chunk(content="test ", metadata={"document_id": "doc-1"}),
+        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
+    ]
+
+    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
+    assert chunk_ids == [
+        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
+        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
+        "f68df25d-d9aa-ab4d-5684-64a233add20d",
+    ]
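
The expected UUIDs follow mechanically from the hash, and the "test" / "test " pair checks that even trailing whitespace changes the ID. If the scheme ever changes, the values can be regenerated with a snippet like this (illustrative; the test compares against the sorted list, so sort the output before pasting):

    import hashlib
    import uuid

    for text in ["test", "test ", "test 3"]:
        hash_input = f"doc-1:{text}".encode("utf-8")
        print(str(uuid.UUID(hashlib.md5(hash_input).hexdigest())))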