diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 3e673aeed..60b4bd4f0 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11134,9 +11134,9 @@
"type": "integer",
"default": 5
},
- "include_metadata_in_content": {
- "type": "boolean",
- "default": false
+ "chunk_template": {
+ "type": "string",
+ "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
}
},
"additionalProperties": false,
@@ -11144,7 +11144,7 @@
"query_generator_config",
"max_tokens_in_context",
"max_chunks",
- "include_metadata_in_content"
+ "chunk_template"
],
"title": "RAGQueryConfig"
},
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 5fdabfb2e..7cb943939 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7690,15 +7690,20 @@ components:
max_chunks:
type: integer
default: 5
- include_metadata_in_content:
- type: boolean
- default: false
+ chunk_template:
+ type: string
+ default: >
+ Result {index}
+
+ Content: {chunk.content}
+
+ Metadata: {metadata}
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- - include_metadata_in_content
+ - chunk_template
title: RAGQueryConfig
RAGQueryGeneratorConfig:
oneOf:
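The YAML default above uses a folded block scalar (`>`): single line breaks fold into spaces and blank lines become newlines, which is why each template line is separated by a blank line. A quick sketch of how that scalar resolves (assuming PyYAML is available), matching the JSON default in `llama-stack-spec.html`:

```python
# Sketch only: shows how the folded scalar in the YAML spec resolves.
import yaml

snippet = """
chunk_template:
  type: string
  default: >
    Result {index}

    Content: {chunk.content}

    Metadata: {metadata}
"""

resolved = yaml.safe_load(snippet)["chunk_template"]["default"]
assert resolved == "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
```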
diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index 663d25f72..dbe90a7fc 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -99,14 +99,14 @@ results = client.tool_runtime.rag_tool.query(
)
```
-You can configure adding metadata to the context if you find it useful for your application. Simply add:
+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
```python
# Query documents
results = client.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="What do you know about...",
query_config={
- "include_metadata_in_content": True,
+ "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
},
)
```
@@ -131,7 +131,7 @@ agent = Agent(
"query_config": {
"chunk_size_in_tokens": 512,
"chunk_overlap_in_tokens": 0,
- "include_metadata_in_content": False,
+ "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
},
},
}
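To illustrate how the placeholders in `chunk_template` are filled, here is a minimal rendering sketch; the `Chunk` dataclass and the metadata values are stand-ins for illustration, not the actual `llama_stack` types:

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    content: str


template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunk = Chunk(content="LoRA fine-tunes a small set of low-rank adapter matrices.")
metadata = {"document_id": "num-0", "source": "lora_finetune.html"}

# str.format supports attribute access, so {chunk.content} resolves against the
# chunk object, while {index} and {metadata} are plain substitutions.
print(template.format(index=1, chunk=chunk, metadata=metadata))
# Result 1
# Content: LoRA fine-tunes a small set of low-rank adapter matrices.
# Metadata: {'document_id': 'num-0', 'source': 'lora_finetune.html'}
```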
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index 8dc7b7385..6c94507be 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -7,7 +7,7 @@
from enum import Enum
from typing import Annotated, Any, Literal
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -72,7 +72,19 @@ class RAGQueryConfig(BaseModel):
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
- include_metadata_in_content: bool = False
+ # Optional template for formatting each retrieved chunk in the context.
+ # Available placeholders: {index} (1-based chunk ordinal), {metadata} (chunk metadata dict), {chunk.content} (chunk content string).
+ chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+ @field_validator("chunk_template")
+ def validate_chunk_template(cls, v: str) -> str:
+ if "{chunk.content}" not in v:
+ raise ValueError("chunk_template must contain {chunk.content}")
+ if "{index}" not in v:
+ raise ValueError("chunk_template must contain {index}")
+ if len(v) == 0:
+ raise ValueError("chunk_template must not be empty")
+ return v
@runtime_checkable
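A quick sketch of what the new validator enforces at the model level (assuming `RAGQueryConfig` is imported from the module changed above):

```python
from llama_stack.apis.tools.rag_tool import RAGQueryConfig

# The default template passes validation.
config = RAGQueryConfig()
assert "{chunk.content}" in config.chunk_template

# Any custom template is accepted as long as {index} and {chunk.content} appear.
RAGQueryConfig(chunk_template="[{index}] {chunk.content}\n")

# A template missing the required placeholders is rejected; pydantic surfaces
# the ValueError from the field_validator as a ValidationError.
try:
    RAGQueryConfig(chunk_template="just some text")
except Exception as exc:
    print(type(exc).__name__)  # ValidationError
```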
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index b67bfcf4a..610e9fbaa 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -146,8 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
- if query_config.include_metadata_in_content:
- tokens += metadata["metadata_token_count"]
+ tokens += metadata["metadata_token_count"]
if tokens > query_config.max_tokens_in_context:
log.error(
@@ -155,15 +154,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
)
break
- text_content = f"Result {i + 1}:\n"
- if query_config.include_metadata_in_content:
- metadata_subset = {
- k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
- }
- text_content += f"\nMetadata: {metadata_subset}"
- else:
- text_content += f"Document_id:{metadata['document_id'][:5]}"
- text_content += f"\nContent: {chunk.content}\n"
+ metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
+ text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
diff --git a/pyproject.toml b/pyproject.toml
index d3cc819be..125c127f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -320,3 +320,6 @@ ignore_missing_imports = true
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
+
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = ["classmethod", "pydantic.field_validator"]
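This pep8-naming setting tells ruff that `pydantic.field_validator` implies a classmethod, so validators whose first argument is `cls` (as in `rag_tool.py` above) are not flagged by N805 even without an explicit `@classmethod`. A minimal illustration of the pattern it allows:

```python
from pydantic import BaseModel, field_validator


class Example(BaseModel):
    name: str

    @field_validator("name")
    def check_name(cls, v: str) -> str:  # cls is fine: field_validator makes this a classmethod
        return v.strip()
```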
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
index 3f13fa101..2d049dc0c 100644
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ b/tests/integration/tool_runtime/test_rag_tool.py
@@ -141,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
- metadata={"author": "llama", "source": url},
+ metadata={},
)
for i, url in enumerate(urls)
]
@@ -205,12 +205,29 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
chunk_size_in_tokens=512,
)
- response = client_with_empty_registry.tool_runtime.rag_tool.query(
+ response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What is the name of the method used for fine-tuning?",
+ )
+ assert_valid_text_response(response_with_metadata)
+ assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
+
+ response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"include_metadata_in_content": True,
+ "chunk_template": "Result {index}\nContent: {chunk.content}\n",
},
)
- assert_valid_text_response(response)
- assert any("metadata:" in chunk.text.lower() for chunk in response.content)
+ assert_valid_text_response(response_without_metadata)
+ assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
+
+ with pytest.raises(ValueError):
+ client_with_empty_registry.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What is the name of the method used for fine-tuning?",
+ query_config={
+ "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
+ },
+ )