mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do?

Adding `ChunkMetadata` so we can properly delete embeddings later. More specifically, this PR refactors and extends the chunk metadata handling in the vector database and introduces a distinction between metadata used for model context and backend-only metadata required for chunk management, storage, and retrieval. It also improves chunk ID generation and propagation throughout the stack, enhances test coverage, and adds new utility modules.

```python
class ChunkMetadata(BaseModel):
    """
    `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that will NOT be inserted into the context during inference, but is required for backend functionality.
    Use `metadata` in `Chunk` for metadata that will be used during inference.
    """
    document_id: str | None = None
    chunk_id: str | None = None
    source: str | None = None
    created_timestamp: int | None = None
    updated_timestamp: int | None = None
    chunk_window: str | None = None
    chunk_tokenizer: str | None = None
    chunk_embedding_model: str | None = None
    chunk_embedding_dimension: int | None = None
    content_token_count: int | None = None
    metadata_token_count: int | None = None
```

Eventually we can migrate the `document_id` out of the `metadata` field. I've introduced the changes so that `ChunkMetadata` is backwards compatible with `metadata`.

<!-- If resolving an issue, uncomment and update the line below -->
Closes https://github.com/meta-llama/llama-stack/issues/2501

## Test Plan

Added unit tests

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
136 lines
4.4 KiB
Python
136 lines
4.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
import os
from collections.abc import AsyncIterator, Iterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio

from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorDB,
    VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
    QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
    QdrantVectorIOAdapter,
)
|
|
|
|
# Unit tests for the QdrantVectorIOAdapter class. This module should only contain
# tests that are specific to this class; more general (API-level) tests belong in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
|
|
|
|
|
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
    """Build an inline Qdrant config whose database file sits inside pytest's ``tmp_path``."""
    db_file = os.path.join(tmp_path, "qdrant.db")
    return InlineQdrantVectorIOConfig(path=db_file)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def loop() -> Iterator[asyncio.AbstractEventLoop]:
    """Provide one fresh asyncio event loop shared across the whole test session.

    The loop is closed at session teardown so it is not leaked until interpreter
    exit (asyncio emits a ResourceWarning for loops that are never closed).
    """
    event_loop = asyncio.new_event_loop()
    yield event_loop
    # close() is idempotent, so this is safe even if something already closed the
    # loop; the loop must not be running at this point.
    event_loop.close()
|
|
|
|
|
|
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
    """Return a mock constrained to the ``VectorDB`` spec.

    Its ``identifier`` is the ``vector_db_id`` fixture value. It reports an embedding
    model named ``"embedding_model"`` with an embedding dimension of 384.
    """
    # Keyword arguments beyond the Mock constructor's own are applied via
    # configure_mock, i.e. set as attributes on the new mock.
    return MagicMock(
        spec=VectorDB,
        embedding_model="embedding_model",
        identifier=vector_db_id,
        embedding_dimension=384,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
    """Return a ``VectorDBStore``-spec'd mock whose awaitable ``get_vector_db`` resolves to ``mock_vector_db``."""
    return MagicMock(
        spec=VectorDBStore,
        get_vector_db=AsyncMock(return_value=mock_vector_db),
    )
|
|
|
|
|
|
@pytest.fixture
def mock_api_service(sample_embeddings):
    """Return an ``Inference``-spec'd mock whose awaitable ``embeddings`` yields ``sample_embeddings``."""
    canned = EmbeddingsResponse(embeddings=sample_embeddings)
    inference = MagicMock(spec=Inference)
    inference.embeddings = AsyncMock(return_value=canned)
    return inference
|
|
|
|
|
|
@pytest_asyncio.fixture
async def qdrant_adapter(
    qdrant_config, mock_vector_db_store, mock_api_service, loop
) -> AsyncIterator[QdrantVectorIOAdapter]:
    """Yield an initialized ``QdrantVectorIOAdapter`` wired to the mocked store and inference API.

    The fixture is an async generator, so its return type is an async iterator of
    the adapter rather than the adapter itself. ``loop`` is requested but not used
    in the body; presumably it is only there to force the session-scoped ``loop``
    fixture to be instantiated (TODO confirm). The adapter is shut down during
    fixture teardown.
    """
    adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
    # Route vector DB lookups through the mocked store instead of a real registry.
    adapter.vector_db_store = mock_vector_db_store
    await adapter.initialize()
    yield adapter
    await adapter.shutdown()
|
|
|
|
|
|
__QUERY = "Sample query"


@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks(
    qdrant_adapter: QdrantVectorIOAdapter,
    vector_db_id,
    sample_chunks,
    sample_embeddings,
    max_query_chunks,
    expected_chunks,
) -> None:
    """Insert ``sample_chunks``, run a vector-mode query capped at ``max_query_chunks``,
    and check that exactly ``expected_chunks`` chunks come back.

    The (100, 60) case presumably reflects the total number of chunks produced by the
    ``sample_chunks`` fixture, which is defined outside this module (TODO confirm).
    """
    assert qdrant_adapter is not None
    await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)

    # After the insert, an index for this vector DB must be obtainable.
    cached_index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
    assert cached_index is not None

    query_params = {"max_chunks": max_query_chunks, "mode": "vector"}
    result = await qdrant_adapter.query_chunks(
        query=__QUERY,
        vector_db_id=vector_db_id,
        params=query_params,
    )
    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == expected_chunks
|
|
|
|
|
|
# To by-pass attempt to convert a Mock to JSON
|
|
def _prepare_for_json(value: Any) -> str:
|
|
return str(value)
|
|
|
|
|
|
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
    qdrant_adapter: QdrantVectorIOAdapter,
    mock_vector_db,
    sample_chunks,
) -> None:
    """Follow a collection through register, first insert, and unregister.

    Registering alone must not create a Qdrant collection. The first insert
    creates it, and unregistering removes it so that no collections remain.
    """
    vector_db_id = mock_vector_db.identifier

    async def collection_count() -> int:
        # Read the client lazily on each call, exactly as the inline checks do.
        listing = await qdrant_adapter.client.get_collections()
        return len(listing.collections)

    # Start from an empty client.
    assert await collection_count() == 0

    # Registration by itself leaves the collection absent.
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)
    await qdrant_adapter.register_vector_db(mock_vector_db)
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)

    # The first insert materializes the collection.
    await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
    assert await qdrant_adapter.client.collection_exists(vector_db_id)

    # Unregistering drops it again and leaves nothing behind.
    await qdrant_adapter.unregister_vector_db(vector_db_id)
    assert not await qdrant_adapter.client.collection_exists(vector_db_id)
    assert await collection_count() == 0
|