Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: Add ChunkMetadata to Chunk (#2497)
# What does this PR do?

Adds `ChunkMetadata` so we can properly delete embeddings later. More specifically, this PR refactors and extends the chunk metadata handling in the vector database and introduces a distinction between metadata used for model context and backend-only metadata required for chunk management, storage, and retrieval. It also improves chunk ID generation and propagation throughout the stack, enhances test coverage, and adds new utility modules.

```python
class ChunkMetadata(BaseModel):
    """
    `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk
    that will NOT be inserted into the context during inference, but is required for backend functionality.
    Use `metadata` in `Chunk` for metadata that will be used during inference.
    """

    document_id: str | None = None
    chunk_id: str | None = None
    source: str | None = None
    created_timestamp: int | None = None
    updated_timestamp: int | None = None
    chunk_window: str | None = None
    chunk_tokenizer: str | None = None
    chunk_embedding_model: str | None = None
    chunk_embedding_dimension: int | None = None
    content_token_count: int | None = None
    metadata_token_count: int | None = None
```

Eventually we can migrate `document_id` out of the `metadata` field. I've introduced the changes so that `ChunkMetadata` is backwards compatible with `metadata`.

Closes https://github.com/meta-llama/llama-stack/issues/2501

## Test Plan

Added unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
fa0b0c13d4
commit
82f13fe83e
14 changed files with 490 additions and 218 deletions
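Before the per-file diff, here is a minimal usage sketch (not part of the diff) of how the backend-only `chunk_metadata` is meant to sit alongside the inference-facing `metadata`, assuming the `Chunk`/`ChunkMetadata` models from `llama_stack.apis.vector_io` shown in the description; the `chunk_id` behavior is the one the new unit tests below assert.

```python
from llama_stack.apis.vector_io import Chunk, ChunkMetadata

# `metadata` is what gets surfaced to the model at inference time;
# `chunk_metadata` carries backend bookkeeping (ids, source, token counts)
# needed to manage/delete embeddings but never inserted into the context.
chunk = Chunk(
    content="Sentence 0 from document 0",
    metadata={"document_id": "document-0"},  # inference-facing, backwards compatible
    chunk_metadata=ChunkMetadata(
        document_id="document-0",
        chunk_id="document-0-chunk-0",
        source="example source-0-0",
    ),
)

# Per the new tests, chunk_id is a computed field: it is not serialized by
# model_dump(), prefers an explicitly stored id, and otherwise falls back to
# a deterministic UUID-formatted string.
print(chunk.chunk_id)
```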
```diff
@@ -9,7 +9,7 @@ import random

 import numpy as np
 import pytest

-from llama_stack.apis.vector_io import Chunk
+from llama_stack.apis.vector_io import Chunk, ChunkMetadata

 EMBEDDING_DIMENSION = 384

@@ -33,6 +33,20 @@ def sample_chunks():
         for j in range(k)
         for i in range(n)
     ]
+    sample.extend(
+        [
+            Chunk(
+                content=f"Sentence {i} from document {j + k}",
+                chunk_metadata=ChunkMetadata(
+                    document_id=f"document-{j + k}",
+                    chunk_id=f"document-{j}-chunk-{i}",
+                    source=f"example source-{j + k}-{i}",
+                ),
+            )
+            for j in range(k)
+            for i in range(n)
+        ]
+    )
     return sample
```
tests/unit/providers/vector_io/test_chunk_utils.py (new file, +66 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id

# This test is a unit test for the chunk_utils.py helpers. This should only contain
# tests which are specific to this file. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_chunk_utils.py \
#   -v -s --tb=short --disable-warnings --asyncio-mode=auto


def test_generate_chunk_id():
    chunks = [
        Chunk(content="test", metadata={"document_id": "doc-1"}),
        Chunk(content="test ", metadata={"document_id": "doc-1"}),
        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
    ]

    chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
    assert chunk_ids == [
        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
        "f68df25d-d9aa-ab4d-5684-64a233add20d",
    ]


def test_chunk_id():
    # Test with existing chunk ID
    chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
    assert chunk_with_id.chunk_id == "84ededcc-b80b-a83e-1a20-ca6515a11350"

    # Test with document ID in metadata
    chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})
    assert chunk_with_doc_id.chunk_id == generate_chunk_id("doc-1", "test")

    # Test chunks with ChunkMetadata
    chunk_with_metadata = Chunk(
        content="test",
        metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"},
        chunk_metadata=ChunkMetadata(document_id="document_1"),
    )
    assert chunk_with_metadata.chunk_id == "chunk-id-1"

    # Test with no ID or document ID
    chunk_without_id = Chunk(content="test")
    generated_id = chunk_without_id.chunk_id
    assert isinstance(generated_id, str) and len(generated_id) == 36  # Should be a valid UUID


def test_stored_chunk_id_alias():
    # Test with existing chunk ID alias
    chunk_with_alias = Chunk(content="test", metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"})
    assert chunk_with_alias.chunk_id == "chunk-id-1"
    serialized_chunk = chunk_with_alias.model_dump()
    assert serialized_chunk["stored_chunk_id"] == "chunk-id-1"
    # showing chunk_id is not serialized (i.e., a computed field)
    assert "chunk_id" not in serialized_chunk
    assert chunk_with_alias.stored_chunk_id == "chunk-id-1"
```
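The fixed UUID strings asserted above imply that `generate_chunk_id(document_id, content)` is deterministic: the same document id and chunk text always map to the same 36-character, UUID-formatted id, which is what makes embeddings addressable for later deletion. The real helper lives in `llama_stack.providers.utils.vector_io.chunk_utils`; the snippet below is only a hypothetical stand-in illustrating that property (the hash function and separator are assumptions, not the library's implementation).

```python
import hashlib
import uuid


def deterministic_chunk_id(document_id: str, chunk_text: str) -> str:
    # Hypothetical stand-in: hash document id and chunk text together, then
    # format the digest as a UUID string so identical inputs always yield
    # the same 36-character id.
    digest = hashlib.md5(f"{document_id}:{chunk_text}".encode()).hexdigest()
    return str(uuid.UUID(digest))


# Same inputs -> same id; the id has the 36-character UUID shape.
assert deterministic_chunk_id("doc-1", "test") == deterministic_chunk_id("doc-1", "test")
assert len(deterministic_chunk_id("doc-1", "test")) == 36
```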
```diff
@@ -81,7 +81,7 @@ __QUERY = "Sample query"


 @pytest.mark.asyncio
-@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
+@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
 async def test_qdrant_adapter_returns_expected_chunks(
     qdrant_adapter: QdrantVectorIOAdapter,
     vector_db_id,
```

(The expected chunk count doubles from 30 to 60 because the `sample_chunks` fixture above now appends a second, `ChunkMetadata`-based set of chunks.)
```diff
@@ -15,7 +15,6 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
     SQLiteVecIndex,
     SQLiteVecVectorIOAdapter,
     _create_sqlite_connection,
-    generate_chunk_id,
 )

 # This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
```
```diff
@@ -65,6 +64,14 @@ async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embed
     assert len(response.chunks) == 2


+@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
+async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    query_embedding = sample_embeddings[0]
+    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
+    assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
+
+
 @pytest.mark.asyncio
 async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
```
```diff
@@ -150,21 +157,6 @@ async def sqlite_vec_adapter(sqlite_connection):
     await adapter.shutdown()


-def test_generate_chunk_id():
-    chunks = [
-        Chunk(content="test", metadata={"document_id": "doc-1"}),
-        Chunk(content="test ", metadata={"document_id": "doc-1"}),
-        Chunk(content="test 3", metadata={"document_id": "doc-1"}),
-    ]
-
-    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
-    assert chunk_ids == [
-        "177a1368-f6a8-0c50-6e92-18677f2c3de3",
-        "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
-        "f68df25d-d9aa-ab4d-5684-64a233add20d",
-    ]
-
-
 @pytest.mark.asyncio
 async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
     """Test hybrid search when keyword search returns no matches - should still return vector results."""
```
```diff
@@ -339,7 +331,7 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
     # Verify scores are in descending order
     assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
     # Verify we get results from both the vector-similar document and keyword-matched document
-    doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
+    doc_ids = {chunk.metadata.get("document_id") or chunk.chunk_metadata.document_id for chunk in response.chunks}
     assert "document-0" in doc_ids  # From vector search
     assert "document-2" in doc_ids  # From keyword search
```
```diff
@@ -364,7 +356,11 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
         reranker_params={"alpha": 1.0},
     )
     assert len(response.chunks) > 0  # Should get at least one result
-    assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+    assert any(
+        "document-0"
+        in (chunk.metadata.get("document_id") or (chunk.chunk_metadata.document_id if chunk.chunk_metadata else ""))
+        for chunk in response.chunks
+    )

     # alpha=0.0 (should behave like pure vector)
     response = await sqlite_vec_index.query_hybrid(
```
```diff
@@ -389,7 +385,11 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
         reranker_params={"alpha": 0.7},
     )
     assert len(response.chunks) > 0  # Should get at least one result
-    assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+    assert any(
+        "document-0"
+        in (chunk.metadata.get("document_id") or (chunk.chunk_metadata.document_id if chunk.chunk_metadata else ""))
+        for chunk in response.chunks
+    )


 @pytest.mark.asyncio
```
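The updated asserts above all repeat the same backward-compatible lookup: read `document_id` from the legacy inference-facing `metadata` dict first, and only fall back to the new backend-only `chunk_metadata` when the legacy key is absent. A small helper sketch of that pattern (the helper name is ours, not part of the PR):

```python
from llama_stack.apis.vector_io import Chunk


def resolve_document_id(chunk: Chunk) -> str | None:
    # Prefer the legacy metadata["document_id"]; fall back to the
    # backend-only chunk_metadata.document_id when it is the only copy.
    return chunk.metadata.get("document_id") or (
        chunk.chunk_metadata.document_id if chunk.chunk_metadata else None
    )
```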
```diff
@@ -4,10 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock

 import pytest

+from llama_stack.apis.vector_io import (
+    Chunk,
+    ChunkMetadata,
+    QueryChunksResponse,
+)
 from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
```
```diff
@@ -17,3 +22,41 @@ class TestRagQuery:
         rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
         with pytest.raises(ValueError):
             await rag_tool.query(content=MagicMock(), vector_db_ids=[])
+
+    @pytest.mark.asyncio
+    async def test_query_chunk_metadata_handling(self):
+        rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
+        content = "test query content"
+        vector_db_ids = ["db1"]
+
+        chunk_metadata = ChunkMetadata(
+            document_id="doc1",
+            chunk_id="chunk1",
+            source="test_source",
+            metadata_token_count=5,
+        )
+        interleaved_content = MagicMock()
+        chunk = Chunk(
+            content=interleaved_content,
+            metadata={
+                "key1": "value1",
+                "token_count": 10,
+                "metadata_token_count": 5,
+                # Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
+                "document_id": "doc1",
+            },
+            stored_chunk_id="chunk1",
+            chunk_metadata=chunk_metadata,
+        )
+
+        query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
+
+        rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
+        result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
+
+        assert result is not None
+        expected_metadata_string = (
+            "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
+        )
+        assert expected_metadata_string in result.content[1].text
+        assert result.content is not None
```
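The expected string in the test above shows how the two metadata channels are merged when the RAG tool renders a retrieved chunk: the backend identifiers from `ChunkMetadata` (`chunk_id`, `document_id`, `source`) are surfaced alongside the inference-facing keys, while bookkeeping fields such as token counts and the duplicated `document_id` are dropped. A hypothetical illustration of that merge (our own helper, not the memory tool's actual code):

```python
def build_metadata_for_context(chunk_metadata: dict, metadata: dict) -> dict:
    # Keep the backend identifiers from ChunkMetadata, then merge in the
    # inference-facing metadata minus bookkeeping fields.
    merged = {k: v for k, v in chunk_metadata.items() if k in ("chunk_id", "document_id", "source")}
    merged.update(
        {k: v for k, v in metadata.items() if k not in ("token_count", "metadata_token_count", "document_id")}
    )
    return merged


merged = build_metadata_for_context(
    {"chunk_id": "chunk1", "document_id": "doc1", "source": "test_source", "metadata_token_count": 5},
    {"key1": "value1", "token_count": 10, "metadata_token_count": 5, "document_id": "doc1"},
)
# Reproduces the string asserted in the test above.
assert f"Metadata: {merged}" == (
    "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
)
```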