feat: Adding ChunkMetadata

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-06-23 14:59:11 -04:00
parent 6fde601765
commit f90fce218e
13 changed files with 416 additions and 206 deletions


@@ -19,17 +19,52 @@ from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@json_schema_type
class ChunkMetadata(BaseModel):
"""
`ChunkMetadata` is backend metadata for a `Chunk` used to store additional information about the chunk that
will NOT be inserted into the context during inference, but is required for backend functionality.
Use `metadata` in `Chunk` for metadata that will be used during inference.
:param document_id: The ID of the document this chunk belongs to.
:param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
:param source: The source of the content, such as a URL or file path.
:param created_timestamp: An optional timestamp indicating when the chunk was created.
:param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
:param chunk_window: The token window of the chunk (e.g. `0-512`), which can be used to group related chunks together.
:param chunk_tokenizer: The tokenizer used to create the chunk. Default is Tiktoken.
:param chunk_embedding_model: The embedding model used to create the chunk's embedding.
:param chunk_embedding_dimension: The dimension of the embedding vector for the chunk.
:param content_token_count: The number of tokens in the content of the chunk.
:param metadata_token_count: The number of tokens in the metadata of the chunk.
"""
document_id: str | None = None
chunk_id: str | None = None
source: str | None = None
created_timestamp: int | None = None
updated_timestamp: int | None = None
chunk_window: str | None = None
chunk_tokenizer: str | None = None
chunk_embedding_model: str | None = None
chunk_embedding_dimension: int | None = None
content_token_count: int | None = None
metadata_token_count: int | None = None
@json_schema_type
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk that will be used during inference.
:param chunk_metadata: Metadata for the chunk that will NOT be inserted into the context during inference
but is required for backend functionality.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
chunk_metadata: ChunkMetadata | None = None
@json_schema_type

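For orientation, here is a minimal sketch of how the two metadata fields divide responsibilities (illustrative values, not part of this commit): `metadata` travels into the inference context, while `chunk_metadata` stays backend-only.

from llama_stack.apis.vector_io import Chunk, ChunkMetadata

chunk = Chunk(
    content="Deep learning is a subset of machine learning.",
    metadata={"document_id": "doc-1", "source": "intro.md"},  # used during inference
    chunk_metadata=ChunkMetadata(  # backend-only; never inserted into the context
        document_id="doc-1",
        chunk_id="doc-1:chunk-0",
        source="intro.md",
        content_token_count=9,
    ),
)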
View file

@@ -148,6 +148,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
# update chunk.metadata with the chunk.chunk_metadata if it exists
if chunk.chunk_metadata:
metadata = {**metadata, **chunk.chunk_metadata.model_dump()}
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
@@ -157,7 +160,19 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
break
metadata_fields_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"chunk_tokenizer",
"chunk_window",
"chunk_embedding_model",
"created_timestamp",
"updated_timestamp",
"content_token_count",
]
metadata_subset = {k: v for k, v in metadata.items() if k not in metadata_fields_to_exclude_from_context}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
picked.append(TextContentItem(text=text_content))

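To make the exclusion concrete, a small sketch of how the merged metadata is pruned before the chunk is formatted into the context (field values assumed for illustration):

merged = {
    "document_id": "doc-1",
    "source": "intro.md",
    "token_count": 120,             # budget accounting only, excluded below
    "chunk_tokenizer": "tiktoken",  # backend detail, excluded below
}
exclude = {"token_count", "metadata_token_count", "chunk_tokenizer", "chunk_window"}
metadata_subset = {k: v for k, v in merged.items() if k not in exclude}
# -> {"document_id": "doc-1", "source": "intro.md"}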

@@ -5,12 +5,10 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import sqlite3
import struct
from typing import Any
import numpy as np
@@ -33,6 +31,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.vector_io.chunk_utils import extract_or_generate_chunk_id
logger = logging.getLogger(__name__)
@@ -202,8 +201,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# Insert metadata
metadata_data = [
(extract_or_generate_chunk_id(chunk), chunk.model_dump_json()) for chunk in batch_chunks
]
cur.executemany(
f"""
@@ -218,7 +216,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(
(
extract_or_generate_chunk_id(chunk),
serialize_vector(emb.tolist()),
)
)
@@ -230,10 +228,7 @@
)
# Insert FTS content
fts_data = [(extract_or_generate_chunk_id(chunk), chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
@@ -383,11 +378,11 @@ class SQLiteVecIndex(EmbeddingIndex):
# Convert responses to score dictionaries using extract_or_generate_chunk_id
vector_scores = {
extract_or_generate_chunk_id(chunk): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
extract_or_generate_chunk_id(chunk): score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
@@ -410,10 +405,10 @@
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = extract_or_generate_chunk_id(c)
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = extract_or_generate_chunk_id(c)
chunk_map[chunk_id] = c
# Use the map to look up chunks by their IDs
@@ -757,9 +752,3 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)

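A brief note on why one ID helper matters here, as a sketch of the invariant rather than the adapter's actual code: the metadata, embedding, and FTS rows for a chunk must share a key, so the same function feeds all three inserts.

from llama_stack.providers.utils.vector_io.chunk_utils import extract_or_generate_chunk_id

def rows_for(chunk, embedding_blob):
    chunk_id = extract_or_generate_chunk_id(chunk)  # one stable ID per chunk
    return (
        (chunk_id, chunk.model_dump_json()),  # metadata table row
        (chunk_id, embedding_blob),           # vector (embedding) table row
        (chunk_id, chunk.content),            # FTS table row
    )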

@@ -72,7 +72,11 @@ class QdrantIndex(EmbeddingIndex):
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
chunk_id = (
f"{chunk.metadata.get('document_id')}:chunk-{i}"
if chunk.metadata
else f"{chunk.chunk_metadata.document_id}:chunk-{i}"
)
points.append(
PointStruct(
id=convert_id(chunk_id),


@@ -7,6 +7,7 @@ import base64
import io
import logging
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
@@ -23,12 +24,13 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
log = logging.getLogger(__name__)
@@ -148,6 +150,10 @@ async def content_from_doc(doc: RAGDocument) -> str:
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
default_embedding_model = (
"DEFAULT_EMBEDDING_MODEL" # This will be correctly updated in `VectorDBWithIndex.insert_chunks`
)
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
@@ -166,11 +172,25 @@
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
backend_chunk_metadata = ChunkMetadata(
document_id=document_id,
chunk_id=generate_chunk_id(document_id, chunk),
source=metadata.get("source", None),
created_timestamp=metadata.get("created_timestamp", int(time.time())),
updated_timestamp=int(time.time()),
chunk_window=f"{i}-{i + len(toks)}",
chunk_tokenizer=default_tokenizer,
chunk_embedding_model=default_embedding_model,
content_token_count=len(toks),
metadata_token_count=len(metadata_tokens),
)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)
)
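A quick sketch of what a caller now gets back from the helper (illustrative call; `document_text` is an assumed variable and exact token counts depend on the tokenizer):

# make_overlapped_chunks is defined in llama_stack.providers.utils.memory.vector_store
chunks = make_overlapped_chunks(
    document_id="doc-1",
    text=document_text,
    window_len=512,
    overlap_len=64,
    metadata={"document_id": "doc-1", "source": "intro.md"},
)
first = chunks[0]
first.metadata["token_count"]               # inference-facing token count
first.chunk_metadata.chunk_window           # e.g. "0-512", backend-only
first.chunk_metadata.chunk_embedding_model  # placeholder until insert_chunks runs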
@@ -235,9 +255,13 @@ class VectorDBWithIndex:
) -> None:
chunks_to_embed = []
for i, c in enumerate(chunks):
# this should be done in `make_overlapped_chunks` but we do it here for convenience
if c.embedding is None:
chunks_to_embed.append(c)
else:
if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
if chunks_to_embed:

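The effect of the branch above, sketched with assumed values: a chunk arriving with a precomputed embedding has its backend metadata stamped with the registered database's model and dimension (the model name here is hypothetical).

c = Chunk(
    content="...",
    embedding=[0.1] * 384,
    chunk_metadata=ChunkMetadata(chunk_embedding_model="DEFAULT_EMBEDDING_MODEL"),
)
# What insert_chunks does, per the hunk above:
c.chunk_metadata.chunk_embedding_model = "all-MiniLM-L6-v2"  # replaces the placeholder
c.chunk_metadata.chunk_embedding_dimension = 384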

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import logging
import uuid
from llama_stack.apis.vector_io import Chunk
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
def extract_chunk_id_from_metadata(chunk: Chunk) -> str | None:
"""Extract existing chunk ID from metadata. This is for compatibility with older Chunks
that stored the document_id in the metadata and not in the ChunkMetadata."""
if chunk.chunk_metadata is not None and chunk.chunk_metadata.chunk_id is not None:
return chunk.chunk_metadata.chunk_id
if "chunk_id" in chunk.metadata:
return str(chunk.metadata["chunk_id"])
return None
def extract_or_generate_chunk_id(chunk: Chunk) -> str:
"""Extract existing chunk ID or generate a new one if not present. This is for compatibility with older Chunks
that stored the document_id in the metadata."""
stored_chunk_id = extract_chunk_id_from_metadata(chunk)
if stored_chunk_id:
return stored_chunk_id
elif "document_id" in chunk.metadata:
return generate_chunk_id(chunk.metadata["document_id"], str(chunk.content))
else:
logging.warning("Chunk has no ID or document_id in metadata. Generating random ID.")
return str(uuid.uuid4())
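Finally, a short sketch of the lookup precedence these helpers implement (example values assumed):

from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.utils.vector_io.chunk_utils import (
    extract_or_generate_chunk_id,
    generate_chunk_id,
)

c1 = Chunk(content="a", metadata={}, chunk_metadata=ChunkMetadata(chunk_id="doc-1:chunk-0"))
assert extract_or_generate_chunk_id(c1) == "doc-1:chunk-0"  # 1. ChunkMetadata wins

c2 = Chunk(content="a", metadata={"chunk_id": "legacy-42"})
assert extract_or_generate_chunk_id(c2) == "legacy-42"  # 2. legacy metadata key

c3 = Chunk(content="a", metadata={"document_id": "doc-1"})
assert extract_or_generate_chunk_id(c3) == generate_chunk_id("doc-1", "a")  # 3. deterministic hash

c4 = Chunk(content="a", metadata={})
extract_or_generate_chunk_id(c4)  # 4. falls back to a random uuid4, with a warning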