diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index f1bde880b..5df6db20c 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11294,24 +11294,34 @@
"type": "object",
"properties": {
"query_generator_config": {
- "$ref": "#/components/schemas/RAGQueryGeneratorConfig"
+ "$ref": "#/components/schemas/RAGQueryGeneratorConfig",
+ "description": "Configuration for the query generator."
},
"max_tokens_in_context": {
"type": "integer",
- "default": 4096
+ "default": 4096,
+ "description": "Maximum number of tokens in the context."
},
"max_chunks": {
"type": "integer",
- "default": 5
+ "default": 5,
+ "description": "Maximum number of chunks to retrieve."
+ },
+ "chunk_template": {
+ "type": "string",
+ "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+ "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
- "max_chunks"
+ "max_chunks",
+ "chunk_template"
],
- "title": "RAGQueryConfig"
+ "title": "RAGQueryConfig",
+ "description": "Configuration for RAG query generation."
},
"RAGQueryGeneratorConfig": {
"oneOf": [
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 10b5deec2..fb2dbf241 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7794,18 +7794,37 @@ components:
properties:
query_generator_config:
$ref: '#/components/schemas/RAGQueryGeneratorConfig'
+ description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
+ description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
+ description: Maximum number of chunks to retrieve.
+ chunk_template:
+ type: string
+ default: >
+ Result {index}
+
+ Content: {chunk.content}
+
+ Metadata: {metadata}
+ description: >-
+ Template for formatting each retrieved chunk in the context. Available
+ placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
+ content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
+ {chunk.content}\nMetadata: {metadata}\n"
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
+ - chunk_template
title: RAGQueryConfig
+ description: >-
+ Configuration for RAG query generation.
RAGQueryGeneratorConfig:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
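The folded block scalar (`>`) above joins adjacent lines with spaces and turns each blank line into a single newline, so the YAML `default` parses to exactly the same string as the JSON spec. A quick sanity check, assuming PyYAML is available:

```python
import yaml

snippet = """
default: >
  Result {index}

  Content: {chunk.content}

  Metadata: {metadata}
"""
# Clip chomping (plain ">") keeps exactly one trailing newline.
assert yaml.safe_load(snippet)["default"] == "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
```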
diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index db6303209..dbe90a7fc 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -51,6 +51,7 @@ chunks = [
"mime_type": "text/plain",
"metadata": {
"document_id": "doc1",
+ "author": "Jane Doe",
},
},
]
@@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query(
)
```
+You can control how each retrieved chunk and its metadata are formatted into the context by passing a `chunk_template` in the query config:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What do you know about...",
+ query_config={
+ "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+ },
+)
+```
### Building RAG-Enhanced Agents
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
@@ -115,6 +127,12 @@ agent = Agent(
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
+            # RAGQueryConfig defaults, shown explicitly
+            "query_config": {
+                "max_tokens_in_context": 4096,
+                "max_chunks": 5,
+                "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+            },
},
}
],
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index fdf199b1a..de3e4c62c 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -7,7 +7,7 @@
from enum import Enum
from typing import Annotated, Any, Literal
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
+ """
+    Configuration for RAG query generation.
+
+ :param query_generator_config: Configuration for the query generator.
+ :param max_tokens_in_context: Maximum number of tokens in the context.
+ :param max_chunks: Maximum number of chunks to retrieve.
+ :param chunk_template: Template for formatting each retrieved chunk in the context.
+ Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
+ Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
+ """
+
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
+ chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+    @field_validator("chunk_template")
+    def validate_chunk_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("chunk_template must not be empty")
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_template must contain {chunk.content}")
+        if "{index}" not in v:
+            raise ValueError("chunk_template must contain {index}")
+        return v
@runtime_checkable
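Because `chunk_template` is checked by a pydantic `field_validator`, a bad template fails at model construction time rather than deep inside the provider. A small sketch of both paths, assuming `llama_stack` is importable:

```python
from pydantic import ValidationError

from llama_stack.apis.tools.rag_tool import RAGQueryConfig

# Happy path: defaults apply and the template passes validation.
config = RAGQueryConfig(chunk_template="Result {index}\nContent: {chunk.content}\n")
assert config.max_chunks == 5

# Error path: a template missing the required placeholders is rejected up front.
try:
    RAGQueryConfig(chunk_template="no placeholders here")
except ValidationError as e:
    print(e)  # ... chunk_template must contain {chunk.content} ...
```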
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 968f93354..39f752297 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
+ doc.metadata,
)
)
@@ -142,19 +143,21 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
- for i, c in enumerate(chunks):
- metadata = c.metadata
+ for i, chunk in enumerate(chunks):
+ metadata = chunk.metadata
tokens += metadata["token_count"]
+            tokens += metadata.get("metadata_token_count", 0)  # chunks inserted before this change may lack the key
+
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
- picked.append(
- TextContentItem(
- text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
- )
- )
+
+ metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
+ text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
+ picked.append(TextContentItem(text=text_content))
+
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
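With this change each candidate chunk is charged for both its content tokens and the tokens its metadata adds to the context, so `max_tokens_in_context` reflects what is actually sent to the model. A standalone sketch of the accounting, with plain dicts standing in for chunk metadata:

```python
def pick_within_budget(chunk_metadatas: list[dict], max_tokens_in_context: int) -> list[dict]:
    """Greedily keep chunks until content + metadata tokens would exceed the budget."""
    picked, tokens = [], 0
    for md in chunk_metadatas:
        tokens += md["token_count"] + md["metadata_token_count"]
        if tokens > max_tokens_in_context:
            break  # the overflowing chunk is dropped, as in the provider loop
        picked.append(md)
    return picked


chunks = [{"token_count": 400, "metadata_token_count": 24}] * 4
assert len(pick_within_budget(chunks, max_tokens_in_context=1000)) == 2  # 424 + 424 fit; 1272 > 1000
```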
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 9d892c166..e0e9d0679 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -139,22 +139,32 @@ async def content_from_doc(doc: RAGDocument) -> str:
return interleaved_content_as_str(doc.content)
-def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
+def make_overlapped_chunks(
+ document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
+) -> list[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
+ try:
+ metadata_string = str(metadata)
+ except Exception as e:
+ raise ValueError("Failed to serialize metadata to string") from e
+
+ metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
+ chunk_metadata = metadata.copy()
+ chunk_metadata["document_id"] = document_id
+ chunk_metadata["token_count"] = len(toks)
+ chunk_metadata["metadata_token_count"] = len(metadata_tokens)
+
# chunk is a string
chunks.append(
Chunk(
content=chunk,
- metadata={
- "token_count": len(toks),
- "document_id": document_id,
- },
+ metadata=chunk_metadata,
)
)
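The chunking itself is plain strided slicing: with step `window_len - overlap_len`, chunk `k` covers tokens `[k * step, k * step + window_len)`, and every chunk now carries a copy of the document metadata. A sketch of the windowing over integers instead of tokenizer output:

```python
def overlapped_windows(tokens: list[int], window_len: int, overlap_len: int) -> list[list[int]]:
    step = window_len - overlap_len
    return [tokens[i : i + window_len] for i in range(0, len(tokens), step)]


print(overlapped_windows(list(range(10)), window_len=5, overlap_len=2))
# [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [6, 7, 8, 9], [9]]
```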
diff --git a/pyproject.toml b/pyproject.toml
index ee180c4c9..f1bf7384f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -320,3 +320,6 @@ ignore_missing_imports = true
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
+
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = ["classmethod", "pydantic.field_validator"]
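Context for the config change: pydantic's `@field_validator` implicitly turns the decorated method into a classmethod, so its first parameter is `cls`. Listing it under `classmethod-decorators` keeps ruff's pep8-naming rule N805 ("first argument of a method should be named `self`") from flagging `validate_chunk_template` above.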
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
index c49f507a8..2d049dc0c 100644
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ b/tests/integration/tool_runtime/test_rag_tool.py
@@ -49,7 +49,7 @@ def sample_documents():
]
-def assert_valid_response(response):
+def assert_valid_chunk_response(response):
assert len(response.chunks) > 0
assert len(response.scores) > 0
assert len(response.chunks) == len(response.scores)
@@ -57,6 +57,11 @@ def assert_valid_response(response):
assert isinstance(chunk.content, str)
+def assert_valid_text_response(response):
+ assert len(response.content) > 0
+ assert all(isinstance(chunk.text, str) for chunk in response.content)
+
+
def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
@@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
vector_db_id=vector_db_id,
query=query1,
)
- assert_valid_response(response1)
+ assert_valid_chunk_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Query with semantic similarity
@@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
vector_db_id=vector_db_id,
query=query2,
)
- assert_valid_response(response2)
+ assert_valid_chunk_response(response2)
assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
# Query with limit on number of results (max_chunks=2)
@@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
query=query3,
params={"max_chunks": 2},
)
- assert_valid_response(response3)
+ assert_valid_chunk_response(response3)
assert len(response3.chunks) <= 2
# Query with threshold on similarity score
@@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
query=query4,
params={"score_threshold": 0.01},
)
- assert_valid_response(response4)
+ assert_valid_chunk_response(response4)
assert all(score >= 0.01 for score in response4.scores)
@@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs
- # URLs of documents to insert
- # TODO: Move to test/memory/resources then update the url to
- # https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
urls = [
"memory_optimizations.rst",
"chat.rst",
@@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
vector_db_id=vector_db_id,
query="What's the name of the fine-tuning method used?",
)
- assert_valid_response(response1)
+ assert_valid_chunk_response(response1)
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
# Query for the name of model
@@ -163,5 +165,69 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
vector_db_id=vector_db_id,
query="Which Llama model is mentioned?",
)
- assert_valid_response(response2)
+ assert_valid_chunk_response(response2)
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
+
+
+def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id):
+ providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+ assert len(providers) > 0
+
+ vector_db_id = "test_vector_db"
+
+ client_with_empty_registry.vector_dbs.register(
+ vector_db_id=vector_db_id,
+ embedding_model=embedding_model_id,
+ embedding_dimension=384,
+ )
+
+ available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
+ assert vector_db_id in available_vector_dbs
+
+ urls = [
+ "memory_optimizations.rst",
+ "chat.rst",
+ "llama3.rst",
+ ]
+ documents = [
+ Document(
+ document_id=f"num-{i}",
+ content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+ mime_type="text/plain",
+ metadata={"author": "llama", "source": url},
+ )
+ for i, url in enumerate(urls)
+ ]
+
+ client_with_empty_registry.tool_runtime.rag_tool.insert(
+ documents=documents,
+ vector_db_id=vector_db_id,
+ chunk_size_in_tokens=512,
+ )
+
+ response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What is the name of the method used for fine-tuning?",
+ )
+ assert_valid_text_response(response_with_metadata)
+ assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
+
+ response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What is the name of the method used for fine-tuning?",
+        query_config={
+            # No {metadata} placeholder, so metadata is left out of the context.
+            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
+        },
+ )
+ assert_valid_text_response(response_without_metadata)
+ assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
+
+ with pytest.raises(ValueError):
+ client_with_empty_registry.tool_runtime.rag_tool.query(
+ vector_db_ids=[vector_db_id],
+ content="What is the name of the method used for fine-tuning?",
+ query_config={
+ "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
+ },
+ )
diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py
index 3decc431e..f97808a6d 100644
--- a/tests/unit/rag/test_vector_store.py
+++ b/tests/unit/rag/test_vector_store.py
@@ -12,7 +12,7 @@ from pathlib import Path
import pytest
from llama_stack.apis.tools import RAGDocument
-from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc
+from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
@@ -76,3 +76,53 @@ class TestVectorStore:
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
+
+ @pytest.mark.parametrize(
+ "window_len, overlap_len, expected_chunks",
+ [
+ (5, 2, 4), # Create 4 chunks with window of 5 and overlap of 2
+ (4, 1, 4), # Create 4 chunks with window of 4 and overlap of 1
+ ],
+ )
+ def test_make_overlapped_chunks(self, window_len, overlap_len, expected_chunks):
+ document_id = "test_doc_123"
+ text = "This is a sample document for testing the chunking behavior"
+ original_metadata = {"source": "test", "date": "2023-01-01", "author": "llama"}
+ len_metadata_tokens = 24 # specific to the metadata above
+
+ chunks = make_overlapped_chunks(document_id, text, window_len, overlap_len, original_metadata)
+
+ assert len(chunks) == expected_chunks
+
+ # Check that each chunk has the right metadata
+ for chunk in chunks:
+ # Original metadata should be preserved
+ assert chunk.metadata["source"] == "test"
+ assert chunk.metadata["date"] == "2023-01-01"
+ assert chunk.metadata["author"] == "llama"
+
+ # New metadata should be added
+ assert chunk.metadata["document_id"] == document_id
+ assert "token_count" in chunk.metadata
+ assert isinstance(chunk.metadata["token_count"], int)
+ assert chunk.metadata["token_count"] > 0
+ assert chunk.metadata["metadata_token_count"] == len_metadata_tokens
+
+ def test_raise_overlapped_chunks_metadata_serialization_error(self):
+ document_id = "test_doc_ex"
+ text = "Some text"
+ window_len = 5
+ overlap_len = 2
+
+ class BadMetadata:
+ def __repr__(self):
+ raise TypeError("Cannot convert to string")
+
+ problematic_metadata = {"bad_metadata_example": BadMetadata()}
+
+ with pytest.raises(ValueError) as excinfo:
+ make_overlapped_chunks(document_id, text, window_len, overlap_len, problematic_metadata)
+
+ assert str(excinfo.value) == "Failed to serialize metadata to string"
+ assert isinstance(excinfo.value.__cause__, TypeError)
+ assert str(excinfo.value.__cause__) == "Cannot convert to string"
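The failure path in this test works because `str()` on a dict builds the `repr()` of every value, so the `TypeError` escapes from `str(metadata)` inside `make_overlapped_chunks` and is re-raised as the asserted `ValueError`. A minimal reproduction:

```python
class BadMetadata:
    def __repr__(self) -> str:
        raise TypeError("Cannot convert to string")


try:
    str({"bad": BadMetadata()})
except TypeError as e:
    print(e)  # Cannot convert to string
```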