feat: Add ChunkMetadata to Chunk (#2497)

# What does this PR do?
Adding `ChunkMetadata` so we can properly delete embeddings later.

More specifically, this PR refactors and extends the chunk metadata
handling in the vector database and introduces a distinction between
metadata used for model context and backend-only metadata required for
chunk management, storage, and retrieval. It also improves chunk ID
generation and propagation throughout the stack, enhances test coverage,
and adds new utility modules.

```python
class ChunkMetadata(BaseModel):
    """
    `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
        will NOT be inserted into the context during inference, but is required for backend functionality.
        Use `metadata` in `Chunk` for metadata that will be used during inference.
    """
    document_id: str | None = None
    chunk_id: str | None = None
    source: str | None = None
    created_timestamp: int | None = None
    updated_timestamp: int | None = None
    chunk_window: str | None = None
    chunk_tokenizer: str | None = None
    chunk_embedding_model: str | None = None
    chunk_embedding_dimension: int | None = None
    content_token_count: int | None = None
    metadata_token_count: int | None = None
```
Eventually we can migrate the document_id out of the `metadata` field.
I've introduced the changes so that `ChunkMetadata` is backwards
compatible with `metadata`.

<!-- If resolving an issue, uncomment and update the line below -->
Closes https://github.com/meta-llama/llama-stack/issues/2501 

## Test Plan
Added unit tests

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Arceo 2025-06-25 13:55:23 -06:00 committed by GitHub
parent fa0b0c13d4
commit 82f13fe83e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 490 additions and 218 deletions

View file

@ -81,6 +81,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
chunks = []
for doc in documents:
content = await content_from_doc(doc)
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
chunks.extend(
make_overlapped_chunks(
doc.document_id,
@ -157,8 +158,24 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
break
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
# Add useful keys from chunk_metadata to metadata and remove some from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))

View file

@ -5,12 +5,10 @@
# the root directory of this source tree.
import asyncio
import hashlib
import json
import logging
import sqlite3
import struct
import uuid
from typing import Any
import numpy as np
@ -201,10 +199,7 @@ class SQLiteVecIndex(EmbeddingIndex):
batch_embeddings = embeddings[i : i + batch_size]
# Insert metadata
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
]
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
@ -218,7 +213,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(
(
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
chunk.chunk_id,
serialize_vector(emb.tolist()),
)
)
@ -230,10 +225,7 @@ class SQLiteVecIndex(EmbeddingIndex):
)
# Insert FTS content
fts_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
for chunk in batch_chunks
]
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
@ -381,13 +373,12 @@ class SQLiteVecIndex(EmbeddingIndex):
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using generate_chunk_id
# Convert responses to score dictionaries using chunk_id
vector_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
@ -408,13 +399,7 @@ class SQLiteVecIndex(EmbeddingIndex):
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
@ -757,9 +742,3 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))