Update chunk ID generation to use a hash of document ID and text content, guaranteeing uniqueness across batches

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Author: Francisco Javier Arceo 2025-02-18 19:53:31 -05:00
Parent: 5a6c95ecf9
Commit: 898d325772
2 changed files with 68 additions and 27 deletions
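
Why the change was needed: the previous scheme keyed each chunk as f"{document_id}:chunk-{j}", where j was the chunk's index within the current batch, so the counter reset to zero on every batch and any document spanning more than one batch produced colliding IDs. A minimal standalone sketch of that failure mode (illustrative only, not the provider code; "doc-1" and batch_size=2 are made-up values):

chunks = [f"sentence {i}" for i in range(4)]  # one document split into 4 chunks
batch_size = 2

old_ids = []
for i in range(0, len(chunks), batch_size):
    batch = chunks[i : i + batch_size]
    # Old scheme: j restarts at 0 for every batch, so IDs repeat across batches.
    old_ids.extend(f"doc-1:chunk-{j}" for j, _ in enumerate(batch))

print(old_ids)  # ['doc-1:chunk-0', 'doc-1:chunk-1', 'doc-1:chunk-0', 'doc-1:chunk-1']
assert len(set(old_ids)) < len(old_ids)  # duplicates: the second batch reuses chunk-0/chunk-1

Hashing the document ID together with the chunk content removes the dependence on batch position entirely, which is what the diff below implements.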

llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import hashlib
 import logging
 import sqlite3
 import struct
+import uuid
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -80,13 +82,13 @@ class SQLiteVecIndex(EmbeddingIndex):
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            for i in range(0, len(chunks), batch_size):
+            for k, i in enumerate(range(0, len(chunks), batch_size)):
                 batch_chunks = chunks[i : i + batch_size]
                 batch_embeddings = embeddings[i : i + batch_size]
                 # Prepare metadata inserts
                 metadata_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
-                    for j, chunk in enumerate(batch_chunks)
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                    for chunk in batch_chunks
                 ]
                 # Insert metadata (ON CONFLICT to avoid duplicates)
                 cur.executemany(
@@ -99,8 +101,8 @@ class SQLiteVecIndex(EmbeddingIndex):
                 )
                 # Prepare embeddings inserts
                 embedding_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
-                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
+                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                 ]
                 # Insert embeddings in batch
                 cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
@@ -227,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
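
The helper is deterministic and content-addressed: the same (document_id, chunk_text) pair always yields the same UUID-shaped string, so the ON CONFLICT clauses above dedupe re-ingested chunks instead of duplicating them, while any change to either input produces a new ID. A quick standalone check (a sketch that copies the helper's logic rather than importing the provider module):

import hashlib
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    # Same logic as the helper above: md5 of "document_id:chunk_text",
    # rendered as a canonical UUID string.
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# Deterministic: identical inputs always map to the same ID.
assert generate_chunk_id("doc-1", "hello") == generate_chunk_id("doc-1", "hello")

# Sensitive to both inputs: a different document or different text gives a new ID.
assert generate_chunk_id("doc-1", "hello") != generate_chunk_id("doc-2", "hello")
assert generate_chunk_id("doc-1", "hello") != generate_chunk_id("doc-1", "hello!")

Note that md5 serves here as a cheap, stable fingerprint for deduplication, not a security boundary; the IDs only need to be deterministic and collision-resistant in practice.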

test_sqlite_vec.py

@@ -13,7 +13,11 @@ import sqlite_vec
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
+    SQLiteVecIndex,
+    SQLiteVecVectorIOAdapter,
+    generate_chunk_id,
+)
 
 
 # How to run this test:
 #
@@ -46,34 +50,27 @@ async def sqlite_vec_index(sqlite_connection):
     return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def sample_chunks():
-    return [
-        Chunk(
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "document_id": "doc 1"},
-        ),
-        Chunk(
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "document_id": "doc 1"},
-        ),
-    ]
+    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
+    n, k = 10, 3
+    sample = [
+        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
+        for j in range(k)
+        for i in range(n)
+    ]
+    return sample
 
 
-@pytest.fixture
-def sample_embeddings():
+@pytest.fixture(scope="session")
+def sample_embeddings(sample_chunks):
     np.random.seed(42)
-    return np.array(
-        [
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-        ]
-    )
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
 
 
 @pytest.mark.asyncio
 async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
     cur = sqlite_vec_index.connection.cursor()
     cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
     count = cur.fetchone()[0]
@@ -84,9 +81,30 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
 async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
     query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
+    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
     assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) > 0
+    assert len(response.chunks) == 2
+
+
+@pytest.mark.asyncio
+async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
+    """Test that chunk IDs do not conflict across batches when inserting chunks."""
+    # Reduce the batch size to force multiple batches for the same document,
+    # since there are 10 chunks per document and the batch size is 2.
+    batch_size = 2
+    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
+
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
+
+    cur = sqlite_vec_index.connection.cursor()
+
+    # Retrieve all chunk IDs to check for duplicates.
+    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
+    chunk_ids = [row[0] for row in cur.fetchall()]
+    cur.close()
+
+    # Ensure all chunk IDs are unique.
+    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
 
 
 @pytest.fixture(scope="session")
@@ -125,3 +143,18 @@ async def test_unregister_vector_db(sqlite_vec_adapter):
     await sqlite_vec_adapter.unregister_vector_db("test_db")
     vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
     assert not any(db.identifier == "test_db" for db in vector_dbs)
+
+
+def test_generate_chunk_id():
+    chunks = [
+        Chunk(content="test", metadata={"document_id": "doc-1"}),
+        Chunk(content="test ", metadata={"document_id": "doc-1"}),
+        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
+    ]
+
+    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
+    assert chunk_ids == [
+        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
+        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
+        "f68df25d-d9aa-ab4d-5684-64a233add20d",
+    ]
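
The three expected constants are simply the md5-derived UUIDs for each (document_id, content) pair, sorted; note that "test" and "test " differ only by trailing whitespace yet still get distinct IDs. If the test inputs ever change, the constants can be regenerated with a snippet like this (a sketch that reuses the hashing logic from generate_chunk_id in the first file):

import hashlib
import uuid

# Recompute the expected IDs for the three test chunks; sorting the printed
# values should reproduce the constants asserted in test_generate_chunk_id.
for text in ["test", "test ", "test 3"]:
    digest = hashlib.md5(f"doc-1:{text}".encode("utf-8")).hexdigest()
    print(repr(text), "->", uuid.UUID(digest))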