forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR introduces support for keyword-based FTS5 search with BM25 relevance scoring. It changes the existing EmbeddingIndex base class to support `search_mode` and `query_str` parameters, which can be used by keyword-based search implementations. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan run ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` Output: ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================== test session starts ======================================================= platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0 asyncio: mode=auto, asyncio_default_fixture_loop_scope=None collected 7 items llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED ``` For reference, with the implementation, the fts table looks like below: ``` Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0 Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0 Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0 Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0 Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0 
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0 Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0 Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0 Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0 Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0 Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1 Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1 Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1 Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1 Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1 ``` Query results: Result 1: Sentence 5 from document 0 Result 2: Sentence 5 from document 1 Result 3: Sentence 5 from document 2 [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import os
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
|
from llama_stack.apis.vector_io import (
|
|
QueryChunksResponse,
|
|
VectorDB,
|
|
VectorDBStore,
|
|
)
|
|
from llama_stack.providers.inline.vector_io.qdrant.config import (
|
|
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
|
|
)
|
|
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
|
QdrantVectorIOAdapter,
|
|
)
|
|
|
|
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
|
# tests/integration/vector_io/
|
|
#
|
|
# How to run this test:
|
|
#
|
|
# pytest tests/unit/providers/vector_io/test_qdrant.py \
|
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
|
|
|
|
|
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
    """Build an inline Qdrant config whose storage file sits in pytest's ``tmp_path``.

    Each test gets its own temporary directory, so the ``qdrant.db`` path is
    unique per test.
    """
    db_file = str(tmp_path / "qdrant.db")
    return InlineQdrantVectorIOConfig(path=db_file)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def loop():
    """Provide one fresh asyncio event loop for the whole test session.

    The loop is yielded rather than returned so that pytest runs the code after
    ``yield`` at session teardown. That code closes the loop. An event loop
    that is never closed keeps its underlying OS resources open and emits an
    "unclosed event loop" ResourceWarning when it is garbage-collected.
    """
    event_loop = asyncio.new_event_loop()
    yield event_loop
    # close() is a no-op on an already-closed loop, so this is safe even if a
    # test closed it first.
    event_loop.close()
|
|
|
|
|
|
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
    """Return a ``VectorDB``-spec'd mock identified by ``vector_db_id``.

    Only two attributes are given concrete values: a placeholder
    ``embedding_model`` name and the ``identifier``.
    """
    # Extra keyword arguments to MagicMock go through configure_mock(), which
    # sets them as ordinary attributes on the spec'd mock.
    return MagicMock(
        spec=VectorDB,
        embedding_model="embedding_model",
        identifier=vector_db_id,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
    """Return a ``VectorDBStore``-spec'd mock that always resolves to ``mock_vector_db``.

    ``get_vector_db`` is an ``AsyncMock`` so callers can ``await`` it. It
    ignores its arguments.
    """
    lookup = AsyncMock(return_value=mock_vector_db)
    return MagicMock(spec=VectorDBStore, get_vector_db=lookup)
|
|
|
|
|
|
@pytest.fixture
def mock_api_service(sample_embeddings):
    """Return an ``Inference``-spec'd mock with a canned ``embeddings`` coroutine.

    Awaiting ``embeddings(...)`` yields an ``EmbeddingsResponse`` that wraps
    ``sample_embeddings``, whatever arguments are passed.
    """
    canned_response = EmbeddingsResponse(embeddings=sample_embeddings)
    return MagicMock(
        spec=Inference,
        embeddings=AsyncMock(return_value=canned_response),
    )
|
|
|
|
|
|
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
    """Yield an initialized ``QdrantVectorIOAdapter`` wired to mocked dependencies.

    The adapter uses ``qdrant_config`` (storage under a temporary path), the
    mocked ``Inference`` service and the mocked vector-DB store. The adapter is
    shut down after the test that requested it finishes.
    """
    # NOTE(review): ``loop`` is not used in this body; requesting it only makes
    # pytest instantiate the session-scoped ``loop`` fixture.
    qdrant = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
    # The store is attached before initialize(), in case setup consults it.
    qdrant.vector_db_store = mock_vector_db_store
    await qdrant.initialize()
    yield qdrant
    await qdrant.shutdown()
|
|
|
|
|
|
# Query text used by the vector-search test. Its wording is not checked; the
# test only asserts the response type and the number of chunks returned.
# NOTE(review): the double leading underscore has no effect at module level,
# but the name would be mangled if it were ever referenced inside a class body.
__QUERY = "Sample query"
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
async def test_qdrant_adapter_returns_expected_chunks(
    qdrant_adapter: QdrantVectorIOAdapter,
    vector_db_id,
    sample_chunks,
    sample_embeddings,
    max_query_chunks,
    expected_chunks,
) -> None:
    """Insert the sample chunks, then check that a vector query honors ``max_chunks``.

    With ``max_chunks=2`` exactly two chunks are expected. With a limit of 100,
    the expected count is 30; presumably this is the size of
    ``sample_chunks`` (confirm against the fixture).
    """
    # ``sample_embeddings`` is not read directly here.
    adapter = qdrant_adapter
    assert adapter is not None

    await adapter.insert_chunks(vector_db_id, sample_chunks)

    # Fetch (and cache) the index for this vector DB; it must exist after the insert.
    cached_index = await adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
    assert cached_index is not None

    query_params = {"max_chunks": max_query_chunks, "mode": "vector"}
    response = await adapter.query_chunks(
        query=__QUERY,
        vector_db_id=vector_db_id,
        params=query_params,
    )

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == expected_chunks
|
|
|
|
|
|
# To by-pass attempt to convert a Mock to JSON
|
|
def _prepare_for_json(value: Any) -> str:
|
|
return str(value)
|
|
|
|
|
|
# Tracing would otherwise try to serialize the mocked arguments to JSON.
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
    qdrant_adapter: QdrantVectorIOAdapter,
    mock_vector_db,
    sample_chunks,
) -> None:
    """Check the collection lifecycle.

    Registering does not create a collection, the first insert does, and
    unregistering removes it again.
    """
    vector_db_id = mock_vector_db.identifier

    async def collection_count() -> int:
        # Reads ``qdrant_adapter.client`` fresh on every call.
        listing = await qdrant_adapter.client.get_collections()
        return len(listing.collections)

    # The adapter starts with no collections at all.
    assert await collection_count() == 0

    # Registering alone does not create a collection.
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)
    await qdrant_adapter.register_vector_db(mock_vector_db)
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)

    # The first insert creates the collection.
    await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
    assert await qdrant_adapter.client.collection_exists(vector_db_id)

    # Unregistering deletes the collection, leaving none behind.
    await qdrant_adapter.unregister_vector_db(vector_db_id)
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)
    assert await collection_count() == 0
|