diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 87e36d6c2..e4bf794e1 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11221,7 +11221,7 @@
                                 }
                             ]
                         },
-                        "description": "Metadata associated with the chunk that will be used during inference."
+                        "description": "Metadata associated with the chunk that will be used in the model context during inference."
                     },
                     "embedding": {
                         "type": "array",
@@ -11230,9 +11230,13 @@
                         },
                         "description": "Optional embedding for the chunk. If not provided, it will be computed later."
                     },
+                    "stored_chunk_id": {
+                        "type": "string",
+                        "description": "The chunk ID that is stored in the vector database. Used for backend functionality."
+                    },
                     "chunk_metadata": {
                         "$ref": "#/components/schemas/ChunkMetadata",
-                        "description": "Metadata for the chunk that will NOT be inserted into the context during inference that is required backend functionality."
+                        "description": "Metadata for the chunk that will NOT be used in the context during inference. The `chunk_metadata` is required for backend functionality."
                     }
                 },
                 "additionalProperties": false,
@@ -11246,16 +11250,17 @@
             "ChunkMetadata": {
                 "type": "object",
                 "properties": {
+                    "chunk_id": {
+                        "type": "string",
+                        "description": "The ID of the chunk. If not set, it will be generated based on the document ID and content."
+                    },
                     "document_id": {
                         "type": "string",
                         "description": "The ID of the document this chunk belongs to."
                     },
-                    "chunk_id": {
-                        "type": "string"
-                    },
                     "source": {
                         "type": "string",
-                        "description": "The source of the content, such as a URL or file path."
+                        "description": "The source of the content, such as a URL, file path, or other identifier."
                     },
                     "created_timestamp": {
                         "type": "integer",
@@ -11291,8 +11296,11 @@
                     }
                 },
                 "additionalProperties": false,
+                "required": [
+                    "chunk_id"
+                ],
                 "title": "ChunkMetadata",
-                "description": "`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that will NOT be inserted into the context during inference, but is required for backend functionality. Use `metadata` in `Chunk` for metadata that will be used during inference."
+                "description": "`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata` is set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not expected to change afterward. Use `Chunk.metadata` for metadata that will be used in the context during inference."
             },
             "InsertChunksRequest": {
                 "type": "object",
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 836ba479f..822462c09 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7886,7 +7886,8 @@ components:
             - type: array
             - type: object
           description: >-
-            Metadata associated with the chunk that will be used during inference.
+            Metadata associated with the chunk that will be used in the model context
+            during inference.
         embedding:
           type: array
           items:
@@ -7894,11 +7895,15 @@
           description: >-
             Optional embedding for the chunk. If not provided, it will be computed
             later.
+        stored_chunk_id:
+          type: string
+          description: >-
+            The chunk ID that is stored in the vector database. Used for backend functionality.
        chunk_metadata:
          $ref: '#/components/schemas/ChunkMetadata'
          description: >-
-            Metadata for the chunk that will NOT be inserted into the context during
-            inference that is required backend functionality.
+            Metadata for the chunk that will NOT be used in the context during inference.
+            The `chunk_metadata` is required for backend functionality.
       additionalProperties: false
       required:
         - content
@@ -7909,16 +7914,19 @@
     ChunkMetadata:
       type: object
       properties:
+        chunk_id:
+          type: string
+          description: >-
+            The ID of the chunk. If not set, it will be generated based on the document
+            ID and content.
         document_id:
           type: string
           description: >-
             The ID of the document this chunk belongs to.
-        chunk_id:
-          type: string
         source:
           type: string
           description: >-
-            The source of the content, such as a URL or file path.
+            The source of the content, such as a URL, file path, or other identifier.
         created_timestamp:
           type: integer
           description: >-
@@ -7952,12 +7960,16 @@
             The number of tokens in the metadata of the chunk.
       additionalProperties: false
+      required:
+        - chunk_id
       title: ChunkMetadata
       description: >-
         `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional
-        information about the chunk that will NOT be inserted into the context
-        during inference, but is required for backend functionality. Use `metadata`
-        in `Chunk` for metadata that will be used during inference.
+        information about the chunk that will not be used in the context during
+        inference, but is required for backend functionality. The `ChunkMetadata` is
+        set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not
+        expected to change afterward. Use `Chunk.metadata` for metadata that will
+        be used in the context during inference.
     InsertChunksRequest:
       type: object
       properties:
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 8af66062a..6762c0d3f 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -8,6 +8,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import uuid
 from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
@@ -15,6 +16,7 @@
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
 from llama_stack.schema_utils import json_schema_type, webmethod
 from llama_stack.strong_typing.schema import register_schema
 
@@ -23,10 +25,12 @@ from llama_stack.strong_typing.schema import register_schema
 class ChunkMetadata(BaseModel):
     """
     `ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
-    will NOT be inserted into the context during inference, but is required for backend functionality.
-    Use `metadata` in `Chunk` for metadata that will be used during inference.
+    will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata`
+    is set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not expected to change afterward.
+    Use `Chunk.metadata` for metadata that will be used in the context during inference.
+    :param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
     :param document_id: The ID of the document this chunk belongs to.
-    :param source: The source of the content, such as a URL or file path.
+    :param source: The source of the content, such as a URL, file path, or other identifier.
     :param created_timestamp: An optional timestamp indicating when the chunk was created.
     :param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
     :param chunk_window: The window of the chunk, which can be used to group related chunks together.
@@ -37,8 +41,8 @@ class ChunkMetadata(BaseModel):
     :param metadata_token_count: The number of tokens in the metadata of the chunk.
     """
 
+    chunk_id: str = None
     document_id: str | None = None
-    chunk_id: str | None = None
     source: str | None = None
     created_timestamp: int | None = None
     updated_timestamp: int | None = None
@@ -56,16 +60,37 @@ class Chunk(BaseModel):
     A chunk of content that can be inserted into a vector database.
     :param content: The content of the chunk, which can be interleaved text, images, or other types.
     :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
-    :param metadata: Metadata associated with the chunk that will be used during inference.
-    :param chunk_metadata: Metadata for the chunk that will NOT be inserted into the context during inference
-        that is required backend functionality.
+    :param metadata: Metadata associated with the chunk that will be used in the model context during inference.
+    :param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
+    :param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
+        The `chunk_metadata` is required for backend functionality.
""" content: InterleavedContent metadata: dict[str, Any] = Field(default_factory=dict) embedding: list[float] | None = None + # The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id" + stored_chunk_id: str | None = Field(default=None, alias="chunk_id") chunk_metadata: ChunkMetadata | None = None + model_config = {"populate_by_name": True} + + def model_post_init(self, __context): + # Extract chunk_id from metadata if present + if self.metadata and "chunk_id" in self.metadata: + self.stored_chunk_id = self.metadata.pop("chunk_id") + + @property + def chunk_id(self) -> str: + """Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set.""" + if self.stored_chunk_id: + return self.stored_chunk_id + + if "document_id" in self.metadata: + return generate_chunk_id(self.metadata["document_id"], str(self.content)) + + return generate_chunk_id(str(uuid.uuid4()), str(self.content)) + @json_schema_type class QueryChunksResponse(BaseModel): diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 084137db2..c0d80172e 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -81,6 +81,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti chunks = [] for doc in documents: content = await content_from_doc(doc) + # TODO: we should add enrichment here as URLs won't be added to the metadata by default chunks.extend( make_overlapped_chunks( doc.document_id, @@ -161,18 +162,19 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti break metadata_fields_to_exclude_from_context = [ - "chunk_tokenizer", - "chunk_window", - "token_count", - "metadata_token_count", - "chunk_tokenizer", - "chunk_embedding_model", "created_timestamp", "updated_timestamp", "chunk_window", + "chunk_tokenizer", + "chunk_embedding_model", + "chunk_embedding_dimension", + "token_count", "content_token_count", + "metadata_token_count", ] - metadata_subset = {k: v for k, v in metadata.items() if k not in metadata_fields_to_exclude_from_context} + metadata_subset = { + k: v for k, v in metadata.items() if k not in metadata_fields_to_exclude_from_context and v + } text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset) picked.append(TextContentItem(text=text_content)) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 74d3f55ec..835bec90a 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -31,7 +31,6 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) -from llama_stack.providers.utils.vector_io.chunk_utils import extract_or_generate_chunk_id logger = logging.getLogger(__name__) @@ -200,9 +199,7 @@ class SQLiteVecIndex(EmbeddingIndex): batch_embeddings = embeddings[i : i + batch_size] # Insert metadata - metadata_data = [ - (extract_or_generate_chunk_id(chunk), chunk.model_dump_json()) for chunk in batch_chunks - ] + metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks] cur.executemany( f""" INSERT INTO {self.metadata_table} (id, chunk) @@ -216,7 +213,7 @@ class SQLiteVecIndex(EmbeddingIndex): embedding_data = [ ( ( - 
-                        extract_or_generate_chunk_id(chunk),
+                        chunk.chunk_id,
                         serialize_vector(emb.tolist()),
                     )
                 )
@@ -228,7 +225,7 @@
             )
 
             # Insert FTS content
-            fts_data = [(extract_or_generate_chunk_id(chunk), chunk.content) for chunk in batch_chunks]
+            fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
             # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
             cur.executemany(
                 f"DELETE FROM {self.fts_table} WHERE id = ?;",
@@ -376,13 +373,12 @@
         vector_response = await self.query_vector(embedding, k, score_threshold)
         keyword_response = await self.query_keyword(query_string, k, score_threshold)
 
-        # Convert responses to score dictionaries using generate_chunk_id
+        # Convert responses to score dictionaries using chunk_id
         vector_scores = {
-            extract_or_generate_chunk_id(chunk): score
-            for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
+            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
         }
         keyword_scores = {
-            extract_or_generate_chunk_id(chunk): score
+            chunk.chunk_id: score
             for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
         }
 
@@ -405,10 +401,10 @@
         # Create a map of chunk_id to chunk for both responses
         chunk_map = {}
         for c in vector_response.chunks:
-            chunk_id = extract_or_generate_chunk_id(c)
+            chunk_id = c.chunk_id
             chunk_map[chunk_id] = c
         for c in keyword_response.chunks:
-            chunk_id = extract_or_generate_chunk_id(c)
+            chunk_id = c.chunk_id
             chunk_map[chunk_id] = c
 
         # Use the map to look up chunks by their IDs
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 3c7fa6430..ab204a75a 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -151,9 +151,6 @@ def make_overlapped_chunks(
     document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
 ) -> list[Chunk]:
     default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
-    default_embedding_model = (
-        "DEFAULT_EMBEDDING_MODEL"  # This will be correctly updated in `VectorDBWithIndex.insert_chunks`
-    )
     tokenizer = Tokenizer.get_instance()
     tokens = tokenizer.encode(text, bos=False, eos=False)
     try:
@@ -167,20 +164,22 @@
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        chunk_id = generate_chunk_id(chunk, text)
         chunk_metadata = metadata.copy()
+        chunk_metadata["chunk_id"] = chunk_id
         chunk_metadata["document_id"] = document_id
         chunk_metadata["token_count"] = len(toks)
         chunk_metadata["metadata_token_count"] = len(metadata_tokens)
 
         backend_chunk_metadata = ChunkMetadata(
+            chunk_id=chunk_id,
             document_id=document_id,
-            chunk_id=generate_chunk_id(chunk, text),
             source=metadata.get("source", None),
             created_timestamp=metadata.get("created_timestamp", int(time.time())),
             updated_timestamp=int(time.time()),
             chunk_window=f"{i}-{i + len(toks)}",
             chunk_tokenizer=default_tokenizer,
-            chunk_embedding_model=default_embedding_model,
+            chunk_embedding_model=None,  # This will be set in `VectorDBWithIndex.insert_chunks`
             content_token_count=len(toks),
             metadata_token_count=len(metadata_tokens),
         )
@@ -255,13 +254,12 @@ class VectorDBWithIndex:
     ) -> None:
         chunks_to_embed = []
         for i, c in enumerate(chunks):
-            # this should be done in `make_overlapped_chunks` but we do it here for convenience
             if c.embedding is None:
                 chunks_to_embed.append(c)
-            else:
                 if c.chunk_metadata:
                     c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
                     c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
+            else:
                 _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
 
         if chunks_to_embed:
diff --git a/llama_stack/providers/utils/vector_io/chunk_utils.py b/llama_stack/providers/utils/vector_io/chunk_utils.py
index 1169eb0cd..68cf11cad 100644
--- a/llama_stack/providers/utils/vector_io/chunk_utils.py
+++ b/llama_stack/providers/utils/vector_io/chunk_utils.py
@@ -5,38 +5,10 @@
 # the root directory of this source tree.
 
 import hashlib
-import logging
 import uuid
 
-from llama_stack.apis.vector_io import Chunk
-
 
 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""
     hash_input = f"{document_id}:{chunk_text}".encode()
     return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
-
-
-def extract_chunk_id_from_metadata(chunk: Chunk) -> str | None:
-    """Extract existing chunk ID from metadata. This is for compatibility with older Chunks
-    that stored the document_id in the metadata and not in the ChunkMetadata."""
-    if chunk.chunk_metadata is not None and hasattr(chunk.chunk_metadata, "chunk_id"):
-        return chunk.chunk_metadata.chunk_id
-
-    if "chunk_id" in chunk.metadata:
-        return str(chunk.metadata["chunk_id"])
-
-    return None
-
-
-def extract_or_generate_chunk_id(chunk: Chunk) -> str:
-    """Extract existing chunk ID or generate a new one if not present. This is for compatibility with older Chunks
-    that stored the document_id in the metadata."""
-    stored_chunk_id = extract_chunk_id_from_metadata(chunk)
-    if stored_chunk_id:
-        return stored_chunk_id
-    elif "document_id" in chunk.metadata:
-        return generate_chunk_id(chunk.metadata["document_id"], str(chunk.content))
-    else:
-        logging.warning("Chunk has no ID or document_id in metadata. Generating random ID.")
-        return str(uuid.uuid4())
diff --git a/tests/unit/providers/vector_io/test_chunk_utils.py b/tests/unit/providers/vector_io/test_chunk_utils.py
index 1549ddd4f..941928b6d 100644
--- a/tests/unit/providers/vector_io/test_chunk_utils.py
+++ b/tests/unit/providers/vector_io/test_chunk_utils.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
-from llama_stack.providers.utils.vector_io.chunk_utils import extract_or_generate_chunk_id, generate_chunk_id
+from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
 
 # This test is a unit test for the chunk_utils.py helpers. This should only contain
 # tests which are specific to this file. More general (API-level) tests should be placed in
@@ -24,7 +24,7 @@ def test_generate_chunk_id():
         Chunk(content="test 3", metadata={"document_id": "doc-1"}),
     ]
 
-    chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
+    chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
     assert chunk_ids == [
         "177a1368-f6a8-0c50-6e92-18677f2c3de3",
         "bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
@@ -32,22 +32,35 @@
     ]
 
 
-def test_extract_or_generate_chunk_id():
+def test_chunk_id():
     # Test with existing chunk ID
     chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
-    assert extract_or_generate_chunk_id(chunk_with_id) == "84ededcc-b80b-a83e-1a20-ca6515a11350"
+    assert chunk_with_id.chunk_id == "84ededcc-b80b-a83e-1a20-ca6515a11350"
 
     # Test with document ID in metadata
     chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})
-    assert extract_or_generate_chunk_id(chunk_with_doc_id) == generate_chunk_id("doc-1", "test")
+    assert chunk_with_doc_id.chunk_id == generate_chunk_id("doc-1", "test")
 
     # Test chunks with ChunkMetadata
     chunk_with_metadata = Chunk(
-        content="test", metadata={"document_id": "existing-id"}, chunk_metadata=ChunkMetadata(chunk_id="chunk-id-1")
+        content="test",
+        metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"},
+        chunk_metadata=ChunkMetadata(document_id="document_1"),
     )
-    assert extract_or_generate_chunk_id(chunk_with_metadata) == "chunk-id-1"
+    assert chunk_with_metadata.chunk_id == "chunk-id-1"
 
     # Test with no ID or document ID
     chunk_without_id = Chunk(content="test")
-    generated_id = extract_or_generate_chunk_id(chunk_without_id)
+    generated_id = chunk_without_id.chunk_id
     assert isinstance(generated_id, str) and len(generated_id) == 36  # Should be a valid UUID
+
+
+def test_stored_chunk_id_alias():
+    # Test with existing chunk ID alias
+    chunk_with_alias = Chunk(content="test", metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"})
+    assert chunk_with_alias.chunk_id == "chunk-id-1"
+    serialized_chunk = chunk_with_alias.model_dump()
+    assert serialized_chunk["stored_chunk_id"] == "chunk-id-1"
+    # showing chunk_id is not serialized (i.e., a computed field)
+    assert "chunk_id" not in serialized_chunk
+    assert chunk_with_alias.stored_chunk_id == "chunk-id-1"
diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py
index 8fa7d3cba..bbac717c7 100644
--- a/tests/unit/providers/vector_io/test_sqlite_vec.py
+++ b/tests/unit/providers/vector_io/test_sqlite_vec.py
@@ -64,6 +64,14 @@ async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embed
     assert len(response.chunks) == 2
 
 
+@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
+async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    query_embedding = sample_embeddings[0]
+    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
+    assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
+
+
 @pytest.mark.asyncio
 async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py
index b9fd8cca4..9a24cff1b 100644
--- a/tests/unit/rag/test_rag_query.py
+++ b/tests/unit/rag/test_rag_query.py
@@ -4,10 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
+from llama_stack.apis.vector_io import (
+    Chunk,
+    ChunkMetadata,
+    QueryChunksResponse,
+)
 from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
 
 
@@ -17,3 +22,41 @@ class TestRagQuery:
         rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
         with pytest.raises(ValueError):
             await rag_tool.query(content=MagicMock(), vector_db_ids=[])
+
+    @pytest.mark.asyncio
+    async def test_query_chunk_metadata_handling(self):
+        rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
+        content = "test query content"
+        vector_db_ids = ["db1"]
+
+        chunk_metadata = ChunkMetadata(
+            document_id="doc1",
+            chunk_id="chunk1",
+            source="test_source",
+            metadata_token_count=5,
+        )
+        interleaved_content = MagicMock()
+        chunk = Chunk(
+            content=interleaved_content,
+            metadata={
+                "key1": "value1",
+                "token_count": 10,
+                "metadata_token_count": 5,
+                # Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
+                "document_id": "doc1",
+            },
+            stored_chunk_id="chunk1",
+            chunk_metadata=chunk_metadata,
+        )
+
+        query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
+
+        rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
+        result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
+
+        assert result is not None
+        expected_metadata_string = (
+            "Metadata: {'key1': 'value1', 'document_id': 'doc1', 'chunk_id': 'chunk1', 'source': 'test_source'}"
+        )
+        assert expected_metadata_string in result.content[1].text
+        assert result.content is not None
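Reviewer note (not part of the diff): a minimal usage sketch of the `Chunk.chunk_id` behavior this PR introduces, mirroring the assertions in `test_chunk_utils.py`. The document IDs and string contents are illustrative only.

```python
# Illustrative sketch; mirrors test_chunk_utils.py assertions, not part of the diff.
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id

# A "chunk_id" key in `metadata` is popped into `stored_chunk_id` by model_post_init,
# and the `chunk_id` property returns it unchanged.
chunk = Chunk(content="hello world", metadata={"document_id": "doc-1", "chunk_id": "chunk-1"})
assert chunk.chunk_id == "chunk-1"
assert "chunk_id" not in chunk.metadata

# Without a stored ID, `chunk_id` is derived deterministically from document_id + content,
# so re-inserting the same chunk yields the same ID.
chunk2 = Chunk(content="hello world", metadata={"document_id": "doc-1"})
assert chunk2.chunk_id == generate_chunk_id("doc-1", "hello world")

# Serialization emits the stored field, not the computed property.
dumped = chunk.model_dump()
assert dumped["stored_chunk_id"] == "chunk-1" and "chunk_id" not in dumped
```

If neither a stored ID nor a `document_id` is present, the property seeds the hash with a random UUID, so the ID is non-deterministic (the bare `Chunk(content="test")` case in `test_chunk_id`).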