qdrant UT

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
This commit is contained in:
Daniele Martinoli 2025-03-17 19:05:16 +01:00
parent 3fa7129816
commit 1d3fccff20
5 changed files with 368 additions and 37 deletions

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import numpy as np
import pytest
from llama_stack.apis.vector_io import Chunk
EMBEDDING_DIMENSION = 384
@pytest.fixture
def vector_db_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@pytest.fixture(scope="session")
def embedding_dimension() -> int:
return EMBEDDING_DIMENSION
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorDB,
VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
import pytest
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
return mock_vector_db
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
mock_store = MagicMock(spec=VectorDBStore)
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
return mock_store
@pytest.fixture
def mock_api_service(sample_embeddings):
mock_api_service = MagicMock(spec=Inference)
mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
return mock_api_service
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
await adapter.shutdown()
__QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
vector_db_id,
sample_chunks,
sample_embeddings,
max_query_chunks,
expected_chunks,
) -> None:
assert qdrant_adapter is not None
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
assert index is not None
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
params={"max_chunks": max_query_chunks},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
# To by-pass attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
return str(value)
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,
sample_chunks,
) -> None:
# Initially, no collections
vector_db_id = mock_vector_db.identifier
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
# Register does not create a collection
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
await qdrant_adapter.register_vector_db(mock_vector_db)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
# First insert creates the collection
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
assert await qdrant_adapter.client.collection_exists(vector_db_id)
# Unregister deletes the collection
await qdrant_adapter.unregister_vector_db(vector_db_id)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
assert len((await qdrant_adapter.client.get_collections()).collections) == 0

View file

@ -29,7 +29,6 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
@ -50,26 +49,8 @@ def sqlite_connection(loop):
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
@pytest.mark.asyncio
@ -82,21 +63,21 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks):
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
# since there are 10 chunks per document and batch size is 2
batch_size = 2
sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32)
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)