mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
updating to use a hash of document id and text content to guarantee uniqueness
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
5a6c95ecf9
commit
898d325772
2 changed files with 68 additions and 27 deletions
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -80,13 +82,13 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
for k, i in enumerate(range(0, len(chunks), batch_size)):
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
# Prepare metadata inserts
|
||||
metadata_data = [
|
||||
(f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
|
||||
for j, chunk in enumerate(batch_chunks)
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
|
@ -99,8 +101,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
)
|
||||
# Prepare embeddings inserts
|
||||
embedding_data = [
|
||||
(f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
|
||||
for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
@ -227,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await self.cache[vector_db_id].query_chunks(query, params)
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
|
|
@ -13,7 +13,11 @@ import sqlite_vec
|
|||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||
SQLiteVecIndex,
|
||||
SQLiteVecVectorIOAdapter,
|
||||
generate_chunk_id,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -46,34 +50,27 @@ async def sqlite_vec_index(sqlite_connection):
|
|||
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
return [
|
||||
Chunk(
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "document_id": "doc 1"},
|
||||
),
|
||||
Chunk(
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "document_id": "doc 1"},
|
||||
),
|
||||
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
|
||||
n, k = 10, 3
|
||||
sample = [
|
||||
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
|
||||
for j in range(k)
|
||||
for i in range(n)
|
||||
]
|
||||
return sample
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings():
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_embeddings(sample_chunks):
|
||||
np.random.seed(42)
|
||||
return np.array(
|
||||
[
|
||||
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
|
||||
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
|
||||
]
|
||||
)
|
||||
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
||||
count = cur.fetchone()[0]
|
||||
|
@ -84,9 +81,30 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
|||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
|
||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||
# Reduce batch size to force multiple batches for same document
|
||||
# since there are 10 chunks per document and batch size is 2
|
||||
batch_size = 2
|
||||
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
|
||||
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
|
||||
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
|
||||
# Ensure all chunk IDs are unique
|
||||
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -125,3 +143,18 @@ async def test_unregister_vector_db(sqlite_vec_adapter):
|
|||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert not any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
def test_generate_chunk_id():
|
||||
chunks = [
|
||||
Chunk(content="test", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test ", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
|
||||
]
|
||||
|
||||
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
|
||||
assert chunk_ids == [
|
||||
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
|
||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue