swapping to configuring the entire chunk template

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-05-13 22:47:35 -04:00
parent 2e70782e63
commit 66f7b42795
7 changed files with 58 additions and 28 deletions

View file

@ -11134,9 +11134,9 @@
"type": "integer", "type": "integer",
"default": 5 "default": 5
}, },
"include_metadata_in_content": { "chunk_template": {
"type": "boolean", "type": "string",
"default": false "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -11144,7 +11144,7 @@
"query_generator_config", "query_generator_config",
"max_tokens_in_context", "max_tokens_in_context",
"max_chunks", "max_chunks",
"include_metadata_in_content" "chunk_template"
], ],
"title": "RAGQueryConfig" "title": "RAGQueryConfig"
}, },

View file

@ -7690,15 +7690,20 @@ components:
max_chunks: max_chunks:
type: integer type: integer
default: 5 default: 5
include_metadata_in_content: chunk_template:
type: boolean type: string
default: false default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
additionalProperties: false additionalProperties: false
required: required:
- query_generator_config - query_generator_config
- max_tokens_in_context - max_tokens_in_context
- max_chunks - max_chunks
- include_metadata_in_content - chunk_template
title: RAGQueryConfig title: RAGQueryConfig
RAGQueryGeneratorConfig: RAGQueryGeneratorConfig:
oneOf: oneOf:

View file

@ -99,14 +99,14 @@ results = client.tool_runtime.rag_tool.query(
) )
``` ```
You can configure adding metadata to the context if you find it useful for your application. Simply add: You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
```python ```python
# Query documents # Query documents
results = client.tool_runtime.rag_tool.query( results = client.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id], vector_db_ids=[vector_db_id],
content="What do you know about...", content="What do you know about...",
query_config={ query_config={
"include_metadata_in_content": True, "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
}, },
) )
``` ```
@ -131,7 +131,7 @@ agent = Agent(
"query_config": { "query_config": {
"chunk_size_in_tokens": 512, "chunk_size_in_tokens": 512,
"chunk_overlap_in_tokens": 0, "chunk_overlap_in_tokens": 0,
"include_metadata_in_content": False, "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
}, },
}, },
} }

View file

@ -7,7 +7,7 @@
from enum import Enum from enum import Enum
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
@ -72,7 +72,19 @@ class RAGQueryConfig(BaseModel):
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig()) query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096 max_tokens_in_context: int = 4096
max_chunks: int = 5 max_chunks: int = 5
include_metadata_in_content: bool = False # Optional template for formatting each retrieved chunk in the context.
# Available placeholders: {index} (1-based chunk ordinal), {metadata} (chunk metadata dict), {chunk.content} (chunk content string).
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable @runtime_checkable

View file

@ -146,8 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
metadata = chunk.metadata metadata = chunk.metadata
tokens += metadata["token_count"] tokens += metadata["token_count"]
if query_config.include_metadata_in_content: tokens += metadata["metadata_token_count"]
tokens += metadata["metadata_token_count"]
if tokens > query_config.max_tokens_in_context: if tokens > query_config.max_tokens_in_context:
log.error( log.error(
@ -155,15 +154,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
) )
break break
text_content = f"Result {i + 1}:\n" # text_content = f"Result {i + 1}:\n"
if query_config.include_metadata_in_content: metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
metadata_subset = { text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
}
text_content += f"\nMetadata: {metadata_subset}"
else:
text_content += f"Document_id:{metadata['document_id'][:5]}"
text_content += f"\nContent: {chunk.content}\n"
picked.append(TextContentItem(text=text_content)) picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))

View file

@ -320,3 +320,6 @@ ignore_missing_imports = true
init_forbid_extra = true init_forbid_extra = true
init_typed = true init_typed = true
warn_required_dynamic_aliases = true warn_required_dynamic_aliases = true
[tool.ruff.lint.pep8-naming]
classmethod-decorators = ["classmethod", "pydantic.field_validator"]

View file

@ -141,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",
metadata={"author": "llama", "source": url}, metadata={},
) )
for i, url in enumerate(urls) for i, url in enumerate(urls)
] ]
@ -205,12 +205,29 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
response = client_with_empty_registry.tool_runtime.rag_tool.query( response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="What is the name of the method used for fine-tuning?",
)
assert_valid_text_response(response_with_metadata)
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id], vector_db_ids=[vector_db_id],
content="What is the name of the method used for fine-tuning?", content="What is the name of the method used for fine-tuning?",
query_config={ query_config={
"include_metadata_in_content": True, "include_metadata_in_content": True,
"chunk_template": "Result {index}\nContent: {chunk.content}\n",
}, },
) )
assert_valid_text_response(response) assert_valid_text_response(response_without_metadata)
assert any("metadata:" in chunk.text.lower() for chunk in response.content) assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
with pytest.raises(ValueError):
client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
},
)