From 8e7ab146f81ad4ce9cda71bfdcb1588f3fa1ec8d Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Wed, 14 May 2025 19:56:20 -0600 Subject: [PATCH] feat: Adding support for customizing chunk context in RAG insertion and querying (#2134) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? his PR allows users to customize the template used for chunks when inserted into the context. Additionally, this enables metadata injection into the context of an LLM for RAG. This makes a naive and crude assumption that each chunk should include the metadata, this is obviously redundant when multiple chunks are returned from the same document. In order to remove any sort of duplication of chunks, we'd have to make much more significant changes so this is a reasonable first step that unblocks users requesting this enhancement in https://github.com/meta-llama/llama-stack/issues/1767. In the future, this can be extended to support citations. List of Changes: - `llama_stack/apis/tools/rag_tool.py` - Added `chunk_template` field in `RAGQueryConfig`. - Added `field_validator` to validate the `chunk_template` field in `RAGQueryConfig`. - Ensured the `chunk_template` field includes placeholders `{index}` and `{chunk.content}`. - Updated the `query` method to use the `chunk_template` for formatting chunk text content. - `llama_stack/providers/inline/tool_runtime/rag/memory.py` - Modified the `insert` method to pass `doc.metadata` for chunk creation. - Enhanced the `query` method to format results using `chunk_template` and exclude unnecessary metadata fields like `token_count`. - `llama_stack/providers/utils/memory/vector_store.py` - Updated `make_overlapped_chunks` to include metadata serialization and token count for both content and metadata. - Added error handling for metadata serialization issues. - `pyproject.toml` - Added `pydantic.field_validator` as a recognized `classmethod` decorator in the linting configuration. - `tests/integration/tool_runtime/test_rag_tool.py` - Refactored test assertions to separate `assert_valid_chunk_response` and `assert_valid_text_response`. - Added integration tests to validate `chunk_template` functionality with and without metadata inclusion. - Included a test case to ensure `chunk_template` validation errors are raised appropriately. - `tests/unit/rag/test_vector_store.py` - Added unit tests for `make_overlapped_chunks`, verifying chunk creation with overlapping tokens and metadata integrity. - Added tests to handle metadata serialization errors, ensuring proper exception handling. - `docs/_static/llama-stack-spec.html` - Added a new `chunk_template` field of type `string` with a default template for formatting retrieved chunks in RAGQueryConfig. - Updated the `required` fields to include `chunk_template`. - `docs/_static/llama-stack-spec.yaml` - Introduced `chunk_template` field with a default value for RAGQueryConfig. - Updated the required configuration list to include `chunk_template`. - `docs/source/building_applications/rag.md` - Documented the `chunk_template` configuration, explaining how to customize metadata formatting in RAG queries. - Added examples demonstrating the usage of the `chunk_template` field in RAG tool queries. - Highlighted default values for `RAG` agent configurations. # Resolves https://github.com/meta-llama/llama-stack/issues/1767 ## Test Plan Updated both `test_vector_store.py` and `test_rag_tool.py` and tested end-to-end with a script. I also tested the quickstart to enable this and specified this metadata: ```python document = RAGDocument( document_id="document_1", content=source, mime_type="text/html", metadata={"author": "Paul Graham", "title": "How to do great work"}, ) ``` Which produced the output below: ![Screenshot 2025-05-13 at 10 53 43 PM](https://github.com/user-attachments/assets/bb199d04-501e-4217-9c44-4699d43d5519) This highlights the usefulness of the additional metadata. Notice how the metadata is redundant for different chunks of the same document. I think we can update that in a subsequent PR. # Documentation I've added a brief comment about this in the documentation to outline this to users and updated the API documentation. --------- Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 20 +++-- docs/_static/llama-stack-spec.yaml | 19 ++++ docs/source/building_applications/rag.md | 18 ++++ llama_stack/apis/tools/rag_tool.py | 24 +++++- .../inline/tool_runtime/rag/memory.py | 17 ++-- .../providers/utils/memory/vector_store.py | 20 +++-- pyproject.toml | 3 + .../integration/tool_runtime/test_rag_tool.py | 86 ++++++++++++++++--- tests/unit/rag/test_vector_store.py | 52 ++++++++++- 9 files changed, 230 insertions(+), 29 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f1bde880b..5df6db20c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11294,24 +11294,34 @@ "type": "object", "properties": { "query_generator_config": { - "$ref": "#/components/schemas/RAGQueryGeneratorConfig" + "$ref": "#/components/schemas/RAGQueryGeneratorConfig", + "description": "Configuration for the query generator." }, "max_tokens_in_context": { "type": "integer", - "default": 4096 + "default": 4096, + "description": "Maximum number of tokens in the context." }, "max_chunks": { "type": "integer", - "default": 5 + "default": 5, + "description": "Maximum number of chunks to retrieve." + }, + "chunk_template": { + "type": "string", + "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" } }, "additionalProperties": false, "required": [ "query_generator_config", "max_tokens_in_context", - "max_chunks" + "max_chunks", + "chunk_template" ], - "title": "RAGQueryConfig" + "title": "RAGQueryConfig", + "description": "Configuration for the RAG query generation." }, "RAGQueryGeneratorConfig": { "oneOf": [ diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 10b5deec2..fb2dbf241 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7794,18 +7794,37 @@ components: properties: query_generator_config: $ref: '#/components/schemas/RAGQueryGeneratorConfig' + description: Configuration for the query generator. max_tokens_in_context: type: integer default: 4096 + description: Maximum number of tokens in the context. max_chunks: type: integer default: 5 + description: Maximum number of chunks to retrieve. + chunk_template: + type: string + default: > + Result {index} + + Content: {chunk.content} + + Metadata: {metadata} + description: >- + Template for formatting each retrieved chunk in the context. Available + placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk + content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: + {chunk.content}\nMetadata: {metadata}\n" additionalProperties: false required: - query_generator_config - max_tokens_in_context - max_chunks + - chunk_template title: RAGQueryConfig + description: >- + Configuration for the RAG query generation. RAGQueryGeneratorConfig: oneOf: - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index db6303209..dbe90a7fc 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -51,6 +51,7 @@ chunks = [ "mime_type": "text/plain", "metadata": { "document_id": "doc1", + "author": "Jane Doe", }, }, ] @@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query( ) ``` +You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add: +```python +# Query documents +results = client.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What do you know about...", + query_config={ + "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + }, +) +``` ### Building RAG-Enhanced Agents One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: @@ -115,6 +127,12 @@ agent = Agent( "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], + # Defaults + "query_config": { + "chunk_size_in_tokens": 512, + "chunk_overlap_in_tokens": 0, + "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + }, }, } ], diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index fdf199b1a..de3e4c62c 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Annotated, Any, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from typing_extensions import Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent @@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") @json_schema_type class RAGQueryConfig(BaseModel): + """ + Configuration for the RAG query generation. + + :param query_generator_config: Configuration for the query generator. + :param max_tokens_in_context: Maximum number of tokens in the context. + :param max_chunks: Maximum number of chunks to retrieve. + :param chunk_template: Template for formatting each retrieved chunk in the context. + Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). + Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" + """ + # This config defines how a query is generated using the messages # for memory bank retrieval. query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig()) max_tokens_in_context: int = 4096 max_chunks: int = 5 + chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + + @field_validator("chunk_template") + def validate_chunk_template(cls, v: str) -> str: + if "{chunk.content}" not in v: + raise ValueError("chunk_template must contain {chunk.content}") + if "{index}" not in v: + raise ValueError("chunk_template must contain {index}") + if len(v) == 0: + raise ValueError("chunk_template must not be empty") + return v @runtime_checkable diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 968f93354..39f752297 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): content, chunk_size_in_tokens, chunk_size_in_tokens // 4, + doc.metadata, ) ) @@ -142,19 +143,21 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" ) ] - for i, c in enumerate(chunks): - metadata = c.metadata + for i, chunk in enumerate(chunks): + metadata = chunk.metadata tokens += metadata["token_count"] + tokens += metadata["metadata_token_count"] + if tokens > query_config.max_tokens_in_context: log.error( f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", ) break - picked.append( - TextContentItem( - text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n", - ) - ) + + metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]} + text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset) + picked.append(TextContentItem(text=text_content)) + picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) picked.append( TextContentItem( diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 9d892c166..e0e9d0679 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -139,22 +139,32 @@ async def content_from_doc(doc: RAGDocument) -> str: return interleaved_content_as_str(doc.content) -def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]: +def make_overlapped_chunks( + document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any] +) -> list[Chunk]: tokenizer = Tokenizer.get_instance() tokens = tokenizer.encode(text, bos=False, eos=False) + try: + metadata_string = str(metadata) + except Exception as e: + raise ValueError("Failed to serialize metadata to string") from e + + metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False) chunks = [] for i in range(0, len(tokens), window_len - overlap_len): toks = tokens[i : i + window_len] chunk = tokenizer.decode(toks) + chunk_metadata = metadata.copy() + chunk_metadata["document_id"] = document_id + chunk_metadata["token_count"] = len(toks) + chunk_metadata["metadata_token_count"] = len(metadata_tokens) + # chunk is a string chunks.append( Chunk( content=chunk, - metadata={ - "token_count": len(toks), - "document_id": document_id, - }, + metadata=chunk_metadata, ) ) diff --git a/pyproject.toml b/pyproject.toml index ee180c4c9..f1bf7384f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -320,3 +320,6 @@ ignore_missing_imports = true init_forbid_extra = true init_typed = true warn_required_dynamic_aliases = true + +[tool.ruff.lint.pep8-naming] +classmethod-decorators = ["classmethod", "pydantic.field_validator"] diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index c49f507a8..2d049dc0c 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -49,7 +49,7 @@ def sample_documents(): ] -def assert_valid_response(response): +def assert_valid_chunk_response(response): assert len(response.chunks) > 0 assert len(response.scores) > 0 assert len(response.chunks) == len(response.scores) @@ -57,6 +57,11 @@ def assert_valid_response(response): assert isinstance(chunk.content, str) +def assert_valid_text_response(response): + assert len(response.content) > 0 + assert all(isinstance(chunk.text, str) for chunk in response.content) + + def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id): vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( @@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do vector_db_id=vector_db_id, query=query1, ) - assert_valid_response(response1) + assert_valid_chunk_response(response1) assert any("Python" in chunk.content for chunk in response1.chunks) # Query with semantic similarity @@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do vector_db_id=vector_db_id, query=query2, ) - assert_valid_response(response2) + assert_valid_chunk_response(response2) assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks) # Query with limit on number of results (max_chunks=2) @@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do query=query3, params={"max_chunks": 2}, ) - assert_valid_response(response3) + assert_valid_chunk_response(response3) assert len(response3.chunks) <= 2 # Query with threshold on similarity score @@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do query=query4, params={"score_threshold": 0.01}, ) - assert_valid_response(response4) + assert_valid_chunk_response(response4) assert all(score >= 0.01 for score in response4.scores) @@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] assert vector_db_id in available_vector_dbs - # URLs of documents to insert - # TODO: Move to test/memory/resources then update the url to - # https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url} urls = [ "memory_optimizations.rst", "chat.rst", @@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ vector_db_id=vector_db_id, query="What's the name of the fine-tunning method used?", ) - assert_valid_response(response1) + assert_valid_chunk_response(response1) assert any("lora" in chunk.content.lower() for chunk in response1.chunks) # Query for the name of model @@ -163,5 +165,69 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ vector_db_id=vector_db_id, query="Which Llama model is mentioned?", ) - assert_valid_response(response2) + assert_valid_chunk_response(response2) assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) + + +def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + assert len(providers) > 0 + + vector_db_id = "test_vector_db" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=384, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + assert vector_db_id in available_vector_dbs + + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + ] + documents = [ + Document( + document_id=f"num-{i}", + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + metadata={"author": "llama", "source": url}, + ) + for i, url in enumerate(urls) + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=vector_db_id, + chunk_size_in_tokens=512, + ) + + response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What is the name of the method used for fine-tuning?", + ) + assert_valid_text_response(response_with_metadata) + assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) + + response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What is the name of the method used for fine-tuning?", + query_config={ + "include_metadata_in_content": True, + "chunk_template": "Result {index}\nContent: {chunk.content}\n", + }, + ) + assert_valid_text_response(response_without_metadata) + assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content) + + with pytest.raises(ValueError): + client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What is the name of the method used for fine-tuning?", + query_config={ + "chunk_template": "This should raise a ValueError because it is missing the proper template variables", + }, + ) diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py index 3decc431e..f97808a6d 100644 --- a/tests/unit/rag/test_vector_store.py +++ b/tests/unit/rag/test_vector_store.py @@ -12,7 +12,7 @@ from pathlib import Path import pytest from llama_stack.apis.tools import RAGDocument -from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc +from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" # Depending on the machine, this can get parsed a couple of ways @@ -76,3 +76,53 @@ class TestVectorStore: ) content = await content_from_doc(doc) assert content in DUMMY_PDF_TEXT_CHOICES + + @pytest.mark.parametrize( + "window_len, overlap_len, expected_chunks", + [ + (5, 2, 4), # Create 4 chunks with window of 5 and overlap of 2 + (4, 1, 4), # Create 4 chunks with window of 4 and overlap of 1 + ], + ) + def test_make_overlapped_chunks(self, window_len, overlap_len, expected_chunks): + document_id = "test_doc_123" + text = "This is a sample document for testing the chunking behavior" + original_metadata = {"source": "test", "date": "2023-01-01", "author": "llama"} + len_metadata_tokens = 24 # specific to the metadata above + + chunks = make_overlapped_chunks(document_id, text, window_len, overlap_len, original_metadata) + + assert len(chunks) == expected_chunks + + # Check that each chunk has the right metadata + for chunk in chunks: + # Original metadata should be preserved + assert chunk.metadata["source"] == "test" + assert chunk.metadata["date"] == "2023-01-01" + assert chunk.metadata["author"] == "llama" + + # New metadata should be added + assert chunk.metadata["document_id"] == document_id + assert "token_count" in chunk.metadata + assert isinstance(chunk.metadata["token_count"], int) + assert chunk.metadata["token_count"] > 0 + assert chunk.metadata["metadata_token_count"] == len_metadata_tokens + + def test_raise_overlapped_chunks_metadata_serialization_error(self): + document_id = "test_doc_ex" + text = "Some text" + window_len = 5 + overlap_len = 2 + + class BadMetadata: + def __repr__(self): + raise TypeError("Cannot convert to string") + + problematic_metadata = {"bad_metadata_example": BadMetadata()} + + with pytest.raises(ValueError) as excinfo: + make_overlapped_chunks(document_id, text, window_len, overlap_len, problematic_metadata) + + assert str(excinfo.value) == "Failed to serialize metadata to string" + assert isinstance(excinfo.value.__cause__, TypeError) + assert str(excinfo.value.__cause__) == "Cannot convert to string"