updating to use a hash of document id and text content to guarantee uniqueness
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 5a6c95ecf9
commit 898d325772
2 changed files with 68 additions and 27 deletions
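The change below replaces positional chunk IDs (document_id plus a per-batch counter) with a deterministic hash of the document ID and the chunk text. A minimal standalone sketch of that ID scheme, mirroring the generate_chunk_id helper added in this commit (the sample document IDs and texts here are made up for illustration):

import hashlib
import uuid


def chunk_id_from_content(document_id: str, chunk_text: str) -> str:
    # Hash "<document_id>:<chunk_text>" and render the 128-bit MD5 digest as a UUID string.
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# The same (document_id, text) pair always maps to the same ID, so re-inserting a chunk
# cannot mint a second key; different text yields a different ID.
a = chunk_id_from_content("doc-1", "Python is a high-level programming language.")
b = chunk_id_from_content("doc-1", "Python is a high-level programming language.")
c = chunk_id_from_content("doc-1", "Machine learning is a subset of artificial intelligence.")
assert a == b and a != c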
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import hashlib
 import logging
 import sqlite3
 import struct
+import uuid
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -80,13 +82,13 @@ class SQLiteVecIndex(EmbeddingIndex):
         try:
             # Start transaction
             cur.execute("BEGIN TRANSACTION")
-            for i in range(0, len(chunks), batch_size):
+            for k, i in enumerate(range(0, len(chunks), batch_size)):
                 batch_chunks = chunks[i : i + batch_size]
                 batch_embeddings = embeddings[i : i + batch_size]
                 # Prepare metadata inserts
                 metadata_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
-                    for j, chunk in enumerate(batch_chunks)
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
+                    for chunk in batch_chunks
                 ]
                 # Insert metadata (ON CONFLICT to avoid duplicates)
                 cur.executemany(
@@ -99,8 +101,8 @@ class SQLiteVecIndex(EmbeddingIndex):
                 )
                 # Prepare embeddings inserts
                 embedding_data = [
-                    (f"{chunk.metadata['document_id']}:chunk-{j}", serialize_vector(emb.tolist()))
-                    for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
+                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
+                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                 ]
                 # Insert embeddings in batch
                 cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
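For context on why the two hunks above drop the per-batch index j: enumerate(batch_chunks) restarts at 0 for every batch, so chunks of the same document that land in different batches could receive identical IDs. A small standalone sketch of that failure mode (plain dicts stand in for the real Chunk objects):

# With the old scheme, the per-batch counter j repeats across batches for one document.
chunks = [{"document_id": "doc-1", "content": f"sentence {i}"} for i in range(4)]
batch_size = 2

old_ids = []
for start in range(0, len(chunks), batch_size):
    batch = chunks[start : start + batch_size]
    old_ids += [f"{c['document_id']}:chunk-{j}" for j, c in enumerate(batch)]

print(old_ids)  # ['doc-1:chunk-0', 'doc-1:chunk-1', 'doc-1:chunk-0', 'doc-1:chunk-1']
assert len(set(old_ids)) < len(old_ids)  # duplicate IDs across batches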
@@ -227,3 +229,9 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
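The remaining hunks are in the unit tests for this provider. One note on the metadata insert above: the exact SQL is not shown in this diff, only the comment about an ON CONFLICT clause, so the following is a toy sqlite3 sketch (hypothetical table and SQL, not the provider's real schema) of how a deterministic ID plus an ON CONFLICT clause makes repeated inserts idempotent:

import hashlib
import sqlite3
import uuid


def chunk_id(document_id: str, chunk_text: str) -> str:
    return str(uuid.UUID(hashlib.md5(f"{document_id}:{chunk_text}".encode("utf-8")).hexdigest()))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE chunks (id TEXT PRIMARY KEY, metadata TEXT);")
row = (chunk_id("doc-1", "some chunk text"), '{"document_id": "doc-1"}')
# Requires SQLite >= 3.24 for UPSERT syntax; the second insert is silently skipped.
conn.execute("INSERT INTO chunks (id, metadata) VALUES (?, ?) ON CONFLICT(id) DO NOTHING;", row)
conn.execute("INSERT INTO chunks (id, metadata) VALUES (?, ?) ON CONFLICT(id) DO NOTHING;", row)
assert conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0] == 1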
@@ -13,7 +13,11 @@ import sqlite_vec
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
+    SQLiteVecIndex,
+    SQLiteVecVectorIOAdapter,
+    generate_chunk_id,
+)
 
 # How to run this test:
 #
@@ -46,34 +50,27 @@ async def sqlite_vec_index(sqlite_connection):
     return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def sample_chunks():
-    return [
-        Chunk(
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "document_id": "doc 1"},
-        ),
-        Chunk(
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "document_id": "doc 1"},
-        ),
-    ]
+    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
+    n, k = 10, 3
+    sample = [
+        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
+        for j in range(k)
+        for i in range(n)
+    ]
+    return sample
 
 
-@pytest.fixture
-def sample_embeddings():
+@pytest.fixture(scope="session")
+def sample_embeddings(sample_chunks):
     np.random.seed(42)
-    return np.array(
-        [
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-            np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
-        ]
-    )
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
 
 
 @pytest.mark.asyncio
 async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
     cur = sqlite_vec_index.connection.cursor()
     cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
     count = cur.fetchone()[0]
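Quick arithmetic on the reworked fixtures above, to see why the new tests exercise the cross-batch case (numbers taken straight from the diff):

n, k = 10, 3                    # 10 sentences per document, 3 documents (sample_chunks)
total_chunks = n * k            # 30 chunks in total
batch_size = 2                  # batch size used by test_add_chunks and test_chunk_id_conflict
batches = -(-total_chunks // batch_size)  # ceiling division -> 15 batches
# Each document's 10 chunks span several batches, so a per-batch counter would repeat.
print(total_chunks, batches)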
@@ -84,9 +81,30 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
 async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
     query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
+    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
     assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) > 0
+    assert len(response.chunks) == 2
+
+
+@pytest.mark.asyncio
+async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
+    """Test that chunk IDs do not conflict across batches when inserting chunks."""
+    # Reduce batch size to force multiple batches for same document
+    # since there are 10 chunks per document and batch size is 2
+    batch_size = 2
+    sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
+
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
+
+    cur = sqlite_vec_index.connection.cursor()
+
+    # Retrieve all chunk IDs to check for duplicates
+    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
+    chunk_ids = [row[0] for row in cur.fetchall()]
+    cur.close()
+
+    # Ensure all chunk IDs are unique
+    assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
 
 
 @pytest.fixture(scope="session")
@@ -125,3 +143,18 @@ async def test_unregister_vector_db(sqlite_vec_adapter):
     await sqlite_vec_adapter.unregister_vector_db("test_db")
     vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
     assert not any(db.identifier == "test_db" for db in vector_dbs)
+
+
+def test_generate_chunk_id():
+    chunks = [
+        Chunk(content="test", metadata={"document_id": "doc-1"}),
+        Chunk(content="test ", metadata={"document_id": "doc-1"}),
+        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
+    ]
+
+    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
+    assert chunk_ids == [
+        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
+        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
+        "f68df25d-d9aa-ab4d-5684-64a233add20d",
+    ]
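To sanity-check the literal IDs asserted in test_generate_chunk_id, a standalone snippet using the same hashing scheme should reproduce them (the test compares the list after sorting):

import hashlib
import uuid

for doc_id, text in [("doc-1", "test"), ("doc-1", "test "), ("doc-1", "test 3")]:
    digest = hashlib.md5(f"{doc_id}:{text}".encode("utf-8")).hexdigest()
    print(repr(text), str(uuid.UUID(digest)))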