mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
swapping to configuring the entire chunk template
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
2e70782e63
commit
66f7b42795
7 changed files with 58 additions and 28 deletions
8
docs/_static/llama-stack-spec.html
vendored
8
docs/_static/llama-stack-spec.html
vendored
|
@ -11134,9 +11134,9 @@
|
|||
"type": "integer",
|
||||
"default": 5
|
||||
},
|
||||
"include_metadata_in_content": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
"chunk_template": {
|
||||
"type": "string",
|
||||
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -11144,7 +11144,7 @@
|
|||
"query_generator_config",
|
||||
"max_tokens_in_context",
|
||||
"max_chunks",
|
||||
"include_metadata_in_content"
|
||||
"chunk_template"
|
||||
],
|
||||
"title": "RAGQueryConfig"
|
||||
},
|
||||
|
|
13
docs/_static/llama-stack-spec.yaml
vendored
13
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7690,15 +7690,20 @@ components:
|
|||
max_chunks:
|
||||
type: integer
|
||||
default: 5
|
||||
include_metadata_in_content:
|
||||
type: boolean
|
||||
default: false
|
||||
chunk_template:
|
||||
type: string
|
||||
default: >
|
||||
Result {index}
|
||||
|
||||
Content: {chunk.content}
|
||||
|
||||
Metadata: {metadata}
|
||||
additionalProperties: false
|
||||
required:
|
||||
- query_generator_config
|
||||
- max_tokens_in_context
|
||||
- max_chunks
|
||||
- include_metadata_in_content
|
||||
- chunk_template
|
||||
title: RAGQueryConfig
|
||||
RAGQueryGeneratorConfig:
|
||||
oneOf:
|
||||
|
|
|
@ -99,14 +99,14 @@ results = client.tool_runtime.rag_tool.query(
|
|||
)
|
||||
```
|
||||
|
||||
You can configure adding metadata to the context if you find it useful for your application. Simply add:
|
||||
You can configure how the RAG tool formats each retrieved chunk (its content and metadata) in the context if you find it useful for your application. Simply set `chunk_template` in the query config:
|
||||
```python
|
||||
# Query documents
|
||||
results = client.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
content="What do you know about...",
|
||||
query_config={
|
||||
"include_metadata_in_content": True,
|
||||
"chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
@ -131,7 +131,7 @@ agent = Agent(
|
|||
"query_config": {
|
||||
"chunk_size_in_tokens": 512,
|
||||
"chunk_overlap_in_tokens": 0,
|
||||
"include_metadata_in_content": False,
|
||||
"chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
|
@ -72,7 +72,19 @@ class RAGQueryConfig(BaseModel):
|
|||
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
include_metadata_in_content: bool = False
|
||||
# Optional template for formatting each retrieved chunk in the context.
|
||||
# Available placeholders: {index} (1-based chunk ordinal), {metadata} (chunk metadata dict), {chunk.content} (chunk content string).
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
|
||||
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
    """Validate a user-supplied ``chunk_template``.

    The template must be non-empty and must contain both the
    ``{chunk.content}`` and ``{index}`` placeholders.

    :param v: candidate template string.
    :returns: ``v`` unchanged when every check passes.
    :raises ValueError: if the template is empty or is missing a
        required placeholder.
    """
    # Check emptiness first: an empty string also lacks every placeholder,
    # so testing it after the placeholder checks made this branch unreachable.
    if not v:
        raise ValueError("chunk_template must not be empty")
    if "{chunk.content}" not in v:
        raise ValueError("chunk_template must contain {chunk.content}")
    if "{index}" not in v:
        raise ValueError("chunk_template must contain {index}")
    return v
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
|
@ -146,8 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if query_config.include_metadata_in_content:
|
||||
tokens += metadata["metadata_token_count"]
|
||||
tokens += metadata["metadata_token_count"]
|
||||
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
log.error(
|
||||
|
@ -155,15 +154,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
break
|
||||
|
||||
text_content = f"Result {i + 1}:\n"
|
||||
if query_config.include_metadata_in_content:
|
||||
metadata_subset = {
|
||||
k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
|
||||
}
|
||||
text_content += f"\nMetadata: {metadata_subset}"
|
||||
else:
|
||||
text_content += f"Document_id:{metadata['document_id'][:5]}"
|
||||
text_content += f"\nContent: {chunk.content}\n"
|
||||
# text_content = f"Result {i + 1}:\n"
|
||||
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
|
||||
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
|
||||
picked.append(TextContentItem(text=text_content))
|
||||
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
|
|
@ -320,3 +320,6 @@ ignore_missing_imports = true
|
|||
init_forbid_extra = true
|
||||
init_typed = true
|
||||
warn_required_dynamic_aliases = true
|
||||
|
||||
[tool.ruff.lint.pep8-naming]
|
||||
classmethod-decorators = ["classmethod", "pydantic.field_validator"]
|
||||
|
|
|
@ -141,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
|
|||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
metadata={"author": "llama", "source": url},
|
||||
metadata={},
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
@ -205,12 +205,29 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
|||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
)
|
||||
assert_valid_text_response(response_with_metadata)
|
||||
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
|
||||
|
||||
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
query_config={
|
||||
"include_metadata_in_content": True,
|
||||
"chunk_template": "Result {index}\nContent: {chunk.content}\n",
|
||||
},
|
||||
)
|
||||
assert_valid_text_response(response)
|
||||
assert any("metadata:" in chunk.text.lower() for chunk in response.content)
|
||||
assert_valid_text_response(response_without_metadata)
|
||||
assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
content="What is the name of the method used for fine-tuning?",
|
||||
query_config={
|
||||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue