llama-stack/tests/unit/providers/vector_io/test_sqlite_vec.py
Francisco Arceo 9e1ddf2b53
chore: Updating sqlite-vec to make non-blocking calls (#1762)
# What does this PR do?
This PR updates the sqlite-vec database calls to be non-blocking. Note
that each operation creates a new connection, which incurs some
performance overhead but is reasonable given [SQLite's threading and
connections constraints](https://www.sqlite.org/threadsafe.html).

Summary of changes:
- Refactored `SQLiteVecIndex` class to store database path instead of
connection object
- Added `_create_sqlite_connection()` helper function to create
connections on demand
- Ensured proper connection closure in all database operations
- Fixed test fixtures to use a file-based SQLite database for
thread-safety
- Updated the `SQLiteVecVectorIOAdapter` class to handle per-operation
connections

This PR helps chip away at
https://github.com/meta-llama/llama-stack/issues/1489

## Test Plan
sqlite-vec unit tests passed locally as well as a test script using the
client as a library.

## Misc

FYI @varshaprasad96 @kevincogan

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
2025-03-23 17:25:44 -07:00

111 lines
4.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
_create_sqlite_connection,
generate_chunk_id,
)
# This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_sqlite_vec.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
# since there are 10 chunks per document and batch size is 2
batch_size = 2
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
await adapter.initialize()
yield adapter
await adapter.shutdown()
def test_generate_chunk_id():
chunks = [
Chunk(content="test", metadata={"document_id": "doc-1"}),
Chunk(content="test ", metadata={"document_id": "doc-1"}),
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
]
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
assert chunk_ids == [
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
]