diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 3e673aeed..60b4bd4f0 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11134,9 +11134,9 @@
                         "type": "integer",
                         "default": 5
                     },
-                    "include_metadata_in_content": {
-                        "type": "boolean",
-                        "default": false
+                    "chunk_template": {
+                        "type": "string",
+                        "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
                     }
                 },
                 "additionalProperties": false,
@@ -11144,7 +11144,7 @@
                     "query_generator_config",
                     "max_tokens_in_context",
                     "max_chunks",
-                    "include_metadata_in_content"
+                    "chunk_template"
                 ],
                 "title": "RAGQueryConfig"
             },
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 5fdabfb2e..7cb943939 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7690,15 +7690,20 @@ components:
         max_chunks:
           type: integer
           default: 5
-        include_metadata_in_content:
-          type: boolean
-          default: false
+        chunk_template:
+          type: string
+          default: >
+            Result {index}
+
+            Content: {chunk.content}
+
+            Metadata: {metadata}
       additionalProperties: false
       required:
         - query_generator_config
         - max_tokens_in_context
         - max_chunks
-        - include_metadata_in_content
+        - chunk_template
       title: RAGQueryConfig
     RAGQueryGeneratorConfig:
       oneOf:
diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index 663d25f72..dbe90a7fc 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -99,14 +99,14 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```
 
-You can configure adding metadata to the context if you find it useful for your application. Simply add:
+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
 ```python
 # Query documents
 results = client.tool_runtime.rag_tool.query(
     vector_db_ids=[vector_db_id],
     content="What do you know about...",
     query_config={
-        "include_metadata_in_content": True,
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
     },
 )
 ```
@@ -131,7 +131,7 @@ agent = Agent(
                 "query_config": {
                     "chunk_size_in_tokens": 512,
                     "chunk_overlap_in_tokens": 0,
-                    "include_metadata_in_content": False,
+                    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
                 },
             },
         }
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index 8dc7b7385..6c94507be 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Annotated, Any, Literal
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Protocol, runtime_checkable
 
 from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -72,7 +72,19 @@ class RAGQueryConfig(BaseModel):
     query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
-    include_metadata_in_content: bool = False
+    # Optional template for formatting each retrieved chunk in the context.
+    # Available placeholders: {index} (1-based chunk ordinal), {metadata} (chunk metadata dict), {chunk.content} (chunk content string).
+    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+    @field_validator("chunk_template")
+    def validate_chunk_template(cls, v: str) -> str:
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_template must contain {chunk.content}")
+        if "{index}" not in v:
+            raise ValueError("chunk_template must contain {index}")
+        if len(v) == 0:
+            raise ValueError("chunk_template must not be empty")
+        return v
 
 
 @runtime_checkable
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index b67bfcf4a..610e9fbaa 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -146,8 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata
             tokens += metadata["token_count"]
-            if query_config.include_metadata_in_content:
-                tokens += metadata["metadata_token_count"]
+            tokens += metadata["metadata_token_count"]
 
             if tokens > query_config.max_tokens_in_context:
                 log.error(
@@ -155,15 +154,9 @@
                 )
                 break
 
-            text_content = f"Result {i + 1}:\n"
-            if query_config.include_metadata_in_content:
-                metadata_subset = {
-                    k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
-                }
-                text_content += f"\nMetadata: {metadata_subset}"
-            else:
-                text_content += f"Document_id:{metadata['document_id'][:5]}"
-            text_content += f"\nContent: {chunk.content}\n"
+            # text_content = f"Result {i + 1}:\n"
+            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
+            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
             picked.append(TextContentItem(text=text_content))
 
         picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
diff --git a/pyproject.toml b/pyproject.toml
index d3cc819be..125c127f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -320,3 +320,6 @@ ignore_missing_imports = true
 init_forbid_extra = true
 init_typed = true
 warn_required_dynamic_aliases = true
+
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = ["classmethod", "pydantic.field_validator"]
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
index 3f13fa101..2d049dc0c 100644
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ b/tests/integration/tool_runtime/test_rag_tool.py
@@ -141,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
-            metadata={"author": "llama", "source": url},
+            metadata={},
         )
         for i, url in enumerate(urls)
     ]
@@ -205,12 +205,29 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
         chunk_size_in_tokens=512,
     )
 
-    response = client_with_empty_registry.tool_runtime.rag_tool.query(
+    response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+        vector_db_ids=[vector_db_id],
+        content="What is the name of the method used for fine-tuning?",
+    )
+    assert_valid_text_response(response_with_metadata)
+    assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
+
+    response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
         vector_db_ids=[vector_db_id],
         content="What is the name of the method used for fine-tuning?",
         query_config={
             "include_metadata_in_content": True,
+            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
         },
     )
-    assert_valid_text_response(response)
-    assert any("metadata:" in chunk.text.lower() for chunk in response.content)
+    assert_valid_text_response(response_without_metadata)
+    assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
+
+    with pytest.raises(ValueError):
+        client_with_empty_registry.tool_runtime.rag_tool.query(
+            vector_db_ids=[vector_db_id],
+            content="What is the name of the method used for fine-tuning?",
+            query_config={
+                "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
+            },
+        )
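
For reference, a minimal standalone sketch of how a chunk_template is rendered per retrieved chunk. The Chunk dataclass and the sample values below are stand-ins used only for illustration (the real chunk type lives in llama_stack); the template string and the format(index=..., chunk=..., metadata=...) call follow the logic added in llama_stack/providers/inline/tool_runtime/rag/memory.py.

# Illustrative sketch only: Chunk is a stand-in for the real llama_stack chunk type.
from dataclasses import dataclass


@dataclass
class Chunk:
    content: str
    metadata: dict


chunk_template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"

chunks = [
    Chunk(content="LoRA is a parameter-efficient fine-tuning method.", metadata={"document_id": "num-0"}),
    Chunk(content="QAT combines quantization with fine-tuning.", metadata={"document_id": "num-1"}),
]

for i, chunk in enumerate(chunks):
    # str.format() resolves {chunk.content} via attribute access, so the template
    # can reference the chunk object directly, as the provider does.
    print(chunk_template.format(index=i + 1, chunk=chunk, metadata=chunk.metadata))

Dropping "Metadata: {metadata}" from the template, as the response_without_metadata test case does, is how a caller opts out of metadata in the context; omitting {index} or {chunk.content} entirely is rejected by the field validator.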