Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: Add ChunkMetadata to Chunk (#2497)
# What does this PR do?

Adding `ChunkMetadata` so we can properly delete embeddings later. More specifically, this PR refactors and extends chunk metadata handling in the vector database: it introduces a distinction between metadata used for the model context and backend-only metadata required for chunk management, storage, and retrieval. It also improves chunk ID generation and propagation throughout the stack, enhances test coverage, and adds new utility modules.

```python
class ChunkMetadata(BaseModel):
    """
    `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk
    that will NOT be inserted into the context during inference, but is required for backend functionality.
    Use `metadata` in `Chunk` for metadata that will be used during inference.
    """

    document_id: str | None = None
    chunk_id: str | None = None
    source: str | None = None
    created_timestamp: int | None = None
    updated_timestamp: int | None = None
    chunk_window: str | None = None
    chunk_tokenizer: str | None = None
    chunk_embedding_model: str | None = None
    chunk_embedding_dimension: int | None = None
    content_token_count: int | None = None
    metadata_token_count: int | None = None
```

Eventually we can migrate `document_id` out of the `metadata` field. I've introduced the changes so that `ChunkMetadata` is backwards compatible with `metadata`.

Closes https://github.com/meta-llama/llama-stack/issues/2501

## Test Plan

Added unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in: parent fa0b0c13d4 · commit 82f13fe83e

14 changed files with 490 additions and 218 deletions
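To make the backwards-compatibility claim concrete, here is a minimal sketch of the two construction paths (illustrative only, not part of the diff; it assumes a llama-stack install that includes this change, and the document/source values are made up):

```python
from llama_stack.apis.vector_io import Chunk, ChunkMetadata

# Legacy path: chunk_id travels inside `metadata` and is popped into
# `stored_chunk_id` by `Chunk.model_post_init`.
legacy = Chunk(content="hello world", metadata={"document_id": "doc-1", "chunk_id": "abc-123"})
assert legacy.chunk_id == "abc-123"
assert "chunk_id" not in legacy.metadata  # popped for backwards compatibility

# New path: backend-only fields live in `chunk_metadata` and stay out of
# the model context during inference.
chunk = Chunk(
    content="hello world",
    metadata={"document_id": "doc-1"},  # inference-time metadata
    chunk_metadata=ChunkMetadata(document_id="doc-1", source="https://example.com/doc"),
)
print(chunk.chunk_id)  # deterministic ID derived from document_id + content
```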
```diff
@@ -8,6 +8,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import uuid
 from typing import Annotated, Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, Field
```
```diff
@@ -15,21 +16,80 @@ from pydantic import BaseModel, Field

 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
 from llama_stack.schema_utils import json_schema_type, webmethod
 from llama_stack.strong_typing.schema import register_schema


+@json_schema_type
+class ChunkMetadata(BaseModel):
+    """
+    `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk
+    that will not be used in the context during inference, but is required for backend functionality.
+    The `ChunkMetadata` is set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not expected to
+    change after. Use `Chunk.metadata` for metadata that will be used in the context during inference.
+    :param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
+    :param document_id: The ID of the document this chunk belongs to.
+    :param source: The source of the content, such as a URL, file path, or other identifier.
+    :param created_timestamp: An optional timestamp indicating when the chunk was created.
+    :param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
+    :param chunk_window: The window of the chunk, which can be used to group related chunks together.
+    :param chunk_tokenizer: The tokenizer used to create the chunk. Default is Tiktoken.
+    :param chunk_embedding_model: The embedding model used to create the chunk's embedding.
+    :param chunk_embedding_dimension: The dimension of the embedding vector for the chunk.
+    :param content_token_count: The number of tokens in the content of the chunk.
+    :param metadata_token_count: The number of tokens in the metadata of the chunk.
+    """
+
+    chunk_id: str | None = None
+    document_id: str | None = None
+    source: str | None = None
+    created_timestamp: int | None = None
+    updated_timestamp: int | None = None
+    chunk_window: str | None = None
+    chunk_tokenizer: str | None = None
+    chunk_embedding_model: str | None = None
+    chunk_embedding_dimension: int | None = None
+    content_token_count: int | None = None
+    metadata_token_count: int | None = None
+
+
 @json_schema_type
 class Chunk(BaseModel):
     """
     A chunk of content that can be inserted into a vector database.
     :param content: The content of the chunk, which can be interleaved text, images, or other types.
     :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
-    :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
+    :param metadata: Metadata associated with the chunk that will be used in the model context during inference.
+    :param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
+    :param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
+        The `chunk_metadata` is required for backend functionality.
     """

     content: InterleavedContent
     metadata: dict[str, Any] = Field(default_factory=dict)
     embedding: list[float] | None = None
+    # The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
+    stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
+    chunk_metadata: ChunkMetadata | None = None
+
+    model_config = {"populate_by_name": True}
+
+    def model_post_init(self, __context):
+        # Extract chunk_id from metadata if present
+        if self.metadata and "chunk_id" in self.metadata:
+            self.stored_chunk_id = self.metadata.pop("chunk_id")
+
+    @property
+    def chunk_id(self) -> str:
+        """Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
+        if self.stored_chunk_id:
+            return self.stored_chunk_id
+
+        if "document_id" in self.metadata:
+            return generate_chunk_id(self.metadata["document_id"], str(self.content))
+
+        return generate_chunk_id(str(uuid.uuid4()), str(self.content))


 @json_schema_type
```
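A short note on the alias: the sketch below is standard Pydantic v2 behavior, not code from this commit. It shows why `populate_by_name` plus the `chunk_id` alias keeps the wire format unchanged for existing clients:

```python
from llama_stack.apis.vector_io import Chunk

# `populate_by_name` lets callers use either the field name or its alias.
a = Chunk(content="x", chunk_id="id-1")          # via the alias
b = Chunk(content="x", stored_chunk_id="id-1")   # via the field name
assert a.stored_chunk_id == b.stored_chunk_id == "id-1"

# Serializing with by_alias=True keeps the JSON key as "chunk_id", so clients
# never see the internal `stored_chunk_id` name.
print(a.model_dump(by_alias=True)["chunk_id"])  # -> "id-1"
```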
```diff
@@ -81,6 +81,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         chunks = []
         for doc in documents:
             content = await content_from_doc(doc)
+            # TODO: we should add enrichment here as URLs won't be added to the metadata by default
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
```
```diff
@@ -157,8 +158,24 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                )
                break

-            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
-            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
+            # Add useful keys from chunk_metadata to metadata and remove some from metadata
+            chunk_metadata_keys_to_include_from_context = [
+                "chunk_id",
+                "document_id",
+                "source",
+            ]
+            metadata_keys_to_exclude_from_context = [
+                "token_count",
+                "metadata_token_count",
+            ]
+            metadata_for_context = {}
+            for k in chunk_metadata_keys_to_include_from_context:
+                metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
+            for k in metadata:
+                if k not in metadata_keys_to_exclude_from_context:
+                    metadata_for_context[k] = metadata[k]
+
+            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
            picked.append(TextContentItem(text=text_content))

        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
```
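The selection logic above condenses to a small pure function. This restatement is illustrative only: the real code reads attributes off `chunk.chunk_metadata`, and the sample values here are hypothetical:

```python
# Condensed, standalone restatement of the context-metadata filter above;
# plain dicts stand in for the ChunkMetadata object.
def build_metadata_for_context(chunk_metadata: dict, metadata: dict) -> dict:
    include_from_chunk_metadata = ["chunk_id", "document_id", "source"]
    exclude_from_metadata = ["token_count", "metadata_token_count"]
    out = {k: chunk_metadata.get(k) for k in include_from_chunk_metadata}
    out.update({k: v for k, v in metadata.items() if k not in exclude_from_metadata})
    return out

print(build_metadata_for_context(
    {"chunk_id": "c1", "document_id": "d1", "source": "s3://bucket/a.txt"},
    {"author": "jane", "token_count": 512},
))
# -> {'chunk_id': 'c1', 'document_id': 'd1', 'source': 's3://bucket/a.txt', 'author': 'jane'}
```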
```diff
@@ -5,12 +5,10 @@
 # the root directory of this source tree.

 import asyncio
-import hashlib
 import json
 import logging
 import sqlite3
 import struct
-import uuid
 from typing import Any

 import numpy as np
```
```diff
@@ -201,10 +199,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                batch_embeddings = embeddings[i : i + batch_size]

                # Insert metadata
-                metadata_data = [
-                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
-                    for chunk in batch_chunks
-                ]
+                metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
                cur.executemany(
                    f"""
                    INSERT INTO {self.metadata_table} (id, chunk)
```
```diff
@@ -218,7 +213,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                embedding_data = [
                    (
                        (
-                            generate_chunk_id(chunk.metadata["document_id"], chunk.content),
+                            chunk.chunk_id,
                            serialize_vector(emb.tolist()),
                        )
                    )
```
```diff
@@ -230,10 +225,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                )

                # Insert FTS content
-                fts_data = [
-                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
-                    for chunk in batch_chunks
-                ]
+                fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
                # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
                cur.executemany(
                    f"DELETE FROM {self.fts_table} WHERE id = ?;",
```
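Because FTS5 lacks `ON CONFLICT`, the code deletes before inserting. A self-contained illustration of that idiom (toy table name, not the provider's schema; assumes an SQLite build with FTS5, which is typical):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE VIRTUAL TABLE fts_chunks USING fts5(id, content)")

# Re-ingesting the same chunk_id should replace, not duplicate, the row.
rows = [("chunk-1", "first version"), ("chunk-1", "second version")]
for chunk_id, content in rows:
    cur.execute("DELETE FROM fts_chunks WHERE id = ?", (chunk_id,))
    cur.execute("INSERT INTO fts_chunks (id, content) VALUES (?, ?)", (chunk_id, content))

print(cur.execute("SELECT id, content FROM fts_chunks").fetchall())
# -> [('chunk-1', 'second version')]
```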
```diff
@@ -381,13 +373,12 @@ class SQLiteVecIndex(EmbeddingIndex):
        vector_response = await self.query_vector(embedding, k, score_threshold)
        keyword_response = await self.query_keyword(query_string, k, score_threshold)

-        # Convert responses to score dictionaries using generate_chunk_id
+        # Convert responses to score dictionaries using chunk_id
        vector_scores = {
-            generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
-            for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
+            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
-            generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
+            chunk.chunk_id: score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }
```
```diff
@@ -408,13 +399,7 @@ class SQLiteVecIndex(EmbeddingIndex):
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Create a map of chunk_id to chunk for both responses
-        chunk_map = {}
-        for c in vector_response.chunks:
-            chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
-            chunk_map[chunk_id] = c
-        for c in keyword_response.chunks:
-            chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
-            chunk_map[chunk_id] = c
+        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}

        # Use the map to look up chunks by their IDs
        chunks = []
```
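The fusion step that produces `top_k_items` sits outside this hunk. As a hedged sketch only (a simple weighted sum, which may differ from the provider's actual reranker), combining the two `chunk_id`-keyed score dicts could look like:

```python
# Assumed fusion strategy for illustration: weighted sum with weight `alpha`
# on the vector score; missing IDs contribute 0.0.
def fuse_scores(vector_scores: dict[str, float], keyword_scores: dict[str, float], alpha: float = 0.5) -> dict[str, float]:
    all_ids = set(vector_scores) | set(keyword_scores)
    return {
        cid: alpha * vector_scores.get(cid, 0.0) + (1 - alpha) * keyword_scores.get(cid, 0.0)
        for cid in all_ids
    }

fused = fuse_scores({"c1": 0.9, "c2": 0.4}, {"c2": 0.8, "c3": 0.7})
top_k_items = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)
print(top_k_items)  # [('c2', 0.6), ('c1', 0.45), ('c3', 0.35)]
```

Keying everything by the deterministic `chunk_id` is what lets the vector and keyword result sets be joined at all; previously the join key had to be recomputed with `generate_chunk_id` on every lookup.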
```diff
@@ -757,9 +742,3 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        return await self.cache[vector_db_id].query_chunks(query, params)
-
-
-def generate_chunk_id(document_id: str, chunk_text: str) -> str:
-    """Generate a unique chunk ID using a hash of document ID and chunk text."""
-    hash_input = f"{document_id}:{chunk_text}".encode()
-    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
```
```diff
@@ -70,8 +70,8 @@ class QdrantIndex(EmbeddingIndex):
            )

        points = []
-        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
-            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
+        for _i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
+            chunk_id = chunk.chunk_id
            points.append(
                PointStruct(
                    id=convert_id(chunk_id),
```
```diff
@@ -7,6 +7,7 @@ import base64
 import io
 import logging
 import re
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any
```
```diff
@@ -23,12 +24,13 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
+from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
+from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id

 log = logging.getLogger(__name__)

```
```diff
@@ -148,6 +150,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
 def make_overlapped_chunks(
     document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
 ) -> list[Chunk]:
+    default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
     tokenizer = Tokenizer.get_instance()
     tokens = tokenizer.encode(text, bos=False, eos=False)
     try:
```
```diff
@@ -161,16 +164,32 @@ def make_overlapped_chunks(
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
+        chunk_id = generate_chunk_id(chunk, text)
        chunk_metadata = metadata.copy()
+        chunk_metadata["chunk_id"] = chunk_id
        chunk_metadata["document_id"] = document_id
        chunk_metadata["token_count"] = len(toks)
        chunk_metadata["metadata_token_count"] = len(metadata_tokens)
+
+        backend_chunk_metadata = ChunkMetadata(
+            chunk_id=chunk_id,
+            document_id=document_id,
+            source=metadata.get("source", None),
+            created_timestamp=metadata.get("created_timestamp", int(time.time())),
+            updated_timestamp=int(time.time()),
+            chunk_window=f"{i}-{i + len(toks)}",
+            chunk_tokenizer=default_tokenizer,
+            chunk_embedding_model=None,  # This will be set in `VectorDBWithIndex.insert_chunks`
+            content_token_count=len(toks),
+            metadata_token_count=len(metadata_tokens),
+        )
+
        # chunk is a string
        chunks.append(
            Chunk(
                content=chunk,
                metadata=chunk_metadata,
+                chunk_metadata=backend_chunk_metadata,
            )
        )

```
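The stride arithmetic above means consecutive windows share `overlap_len` tokens, and `chunk_window` records each window's span. A quick standalone check with toy token IDs (not the real tokenizer output):

```python
# Replicates the loop bounds from make_overlapped_chunks with small integers.
tokens = list(range(10))
window_len, overlap_len = 4, 2

for i in range(0, len(tokens), window_len - overlap_len):
    toks = tokens[i : i + window_len]
    print(f"chunk_window={i}-{i + len(toks)}", toks)
# chunk_window=0-4 [0, 1, 2, 3]
# chunk_window=2-6 [2, 3, 4, 5]
# chunk_window=4-8 [4, 5, 6, 7]
# chunk_window=6-10 [6, 7, 8, 9]
# chunk_window=8-10 [8, 9]
```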
```diff
@@ -237,6 +256,9 @@ class VectorDBWithIndex:
        for i, c in enumerate(chunks):
            if c.embedding is None:
                chunks_to_embed.append(c)
+                if c.chunk_metadata:
+                    c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
+                    c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
            else:
                _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
```
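A hedged sketch of the stamping rule above: only chunks still awaiting an embedding get the vector DB's model and dimension recorded in their `ChunkMetadata`. The model name and dimension below are assumed config values, not defaults from this commit:

```python
from llama_stack.apis.vector_io import Chunk, ChunkMetadata

chunks = [
    Chunk(content="needs embedding", chunk_metadata=ChunkMetadata(document_id="d1")),
    Chunk(content="precomputed", embedding=[0.1, 0.2, 0.3], chunk_metadata=ChunkMetadata(document_id="d1")),
]
embedding_model, embedding_dimension = "all-MiniLM-L6-v2", 384  # assumed values

for c in chunks:
    if c.embedding is None and c.chunk_metadata:
        c.chunk_metadata.chunk_embedding_model = embedding_model
        c.chunk_metadata.chunk_embedding_dimension = embedding_dimension

print([c.chunk_metadata.chunk_embedding_model for c in chunks])
# -> ['all-MiniLM-L6-v2', None]  (the precomputed chunk is left untouched)
```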
llama_stack/providers/utils/vector_io/__init__.py (new file, 5 lines)

```diff
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
```
llama_stack/providers/utils/vector_io/chunk_utils.py (new file, 14 lines)

```diff
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import hashlib
+import uuid
+
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode()
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
```
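Since the ID is an MD5-derived UUID over `document_id:chunk_text`, re-ingesting identical content yields identical IDs, which is what makes later deletion and update of embeddings possible. A standalone determinism check (self-contained copy of the helper above):

```python
import hashlib
import uuid

def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Generate a unique chunk ID using a hash of document ID and chunk text."""
    hash_input = f"{document_id}:{chunk_text}".encode()
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

a = generate_chunk_id("doc-1", "some chunk text")
b = generate_chunk_id("doc-1", "some chunk text")
c = generate_chunk_id("doc-2", "some chunk text")
assert a == b  # stable across re-ingestion, so rows can be addressed by ID
assert a != c  # different documents get different IDs
print(a)       # a valid UUID string
```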