forked from phoenix-oss/llama-stack-mirror
# What does this PR do? Removed local execution option from the remote Qdrant provider and introduced an explicit inline provider for the embedded execution. Updated the ollama template to include this option: this part can be reverted in case we don't want to have two default `vector_io` providers. (Closes #1082) ## Test Plan Build and run an ollama distro: ```bash llama stack build --template ollama --image-type conda llama stack run --image-type conda ollama ``` Run one of the sample ingestionapplicatinos like [rag_with_vector_db.py](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py), but replace this line: ```py selected_vector_provider = vector_providers[0] ``` with the following, to use the `qdrant` provider: ```py selected_vector_provider = vector_providers[1] ``` After running the test code, verify the timestamp of the Qdrant store: ```bash % ls -ltr ~/.llama/distributions/ollama/qdrant.db/collection/test_vector_db_* total 784 -rw-r--r--@ 1 dmartino staff 401408 Feb 26 10:07 storage.sqlite ``` [//]: # (## Documentation) --------- Signed-off-by: Daniele Martinoli <dmartino@redhat.com> Co-authored-by: Francisco Arceo <farceo@redhat.com>
135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import os
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
|
from llama_stack.apis.vector_io import (
|
|
QueryChunksResponse,
|
|
VectorDB,
|
|
VectorDBStore,
|
|
)
|
|
from llama_stack.providers.inline.vector_io.qdrant.config import (
|
|
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
|
|
)
|
|
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
|
QdrantVectorIOAdapter,
|
|
)
|
|
|
|
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
|
# tests/integration/vector_io/
|
|
#
|
|
# How to run this test:
|
|
#
|
|
# pytest tests/unit/providers/vector_io/test_qdrant.py \
|
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
|
|
|
|
|
@pytest.fixture
|
|
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
|
|
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def loop():
|
|
return asyncio.new_event_loop()
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_vector_db(vector_db_id) -> MagicMock:
|
|
mock_vector_db = MagicMock(spec=VectorDB)
|
|
mock_vector_db.embedding_model = "embedding_model"
|
|
mock_vector_db.identifier = vector_db_id
|
|
return mock_vector_db
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_vector_db_store(mock_vector_db) -> MagicMock:
|
|
mock_store = MagicMock(spec=VectorDBStore)
|
|
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
|
|
return mock_store
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_api_service(sample_embeddings):
|
|
mock_api_service = MagicMock(spec=Inference)
|
|
mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
|
|
return mock_api_service
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
|
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
|
|
adapter.vector_db_store = mock_vector_db_store
|
|
await adapter.initialize()
|
|
yield adapter
|
|
await adapter.shutdown()
|
|
|
|
|
|
__QUERY = "Sample query"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
|
|
async def test_qdrant_adapter_returns_expected_chunks(
|
|
qdrant_adapter: QdrantVectorIOAdapter,
|
|
vector_db_id,
|
|
sample_chunks,
|
|
sample_embeddings,
|
|
max_query_chunks,
|
|
expected_chunks,
|
|
) -> None:
|
|
assert qdrant_adapter is not None
|
|
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
|
|
|
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
|
|
assert index is not None
|
|
|
|
response = await qdrant_adapter.query_chunks(
|
|
query=__QUERY,
|
|
vector_db_id=vector_db_id,
|
|
params={"max_chunks": max_query_chunks},
|
|
)
|
|
assert isinstance(response, QueryChunksResponse)
|
|
assert len(response.chunks) == expected_chunks
|
|
|
|
|
|
# To by-pass attempt to convert a Mock to JSON
|
|
def _prepare_for_json(value: Any) -> str:
|
|
return str(value)
|
|
|
|
|
|
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
|
@pytest.mark.asyncio
|
|
async def test_qdrant_register_and_unregister_vector_db(
|
|
qdrant_adapter: QdrantVectorIOAdapter,
|
|
mock_vector_db,
|
|
sample_chunks,
|
|
) -> None:
|
|
# Initially, no collections
|
|
vector_db_id = mock_vector_db.identifier
|
|
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
|
|
|
# Register does not create a collection
|
|
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
|
await qdrant_adapter.register_vector_db(mock_vector_db)
|
|
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
|
|
|
# First insert creates the collection
|
|
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
|
assert await qdrant_adapter.client.collection_exists(vector_db_id)
|
|
|
|
# Unregister deletes the collection
|
|
await qdrant_adapter.unregister_vector_db(vector_db_id)
|
|
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
|
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|