Mirror of https://github.com/meta-llama/llama-stack.git

Commit 66f7b42795 (parent 2e70782e63)

swapping to configuring the entire chunk template

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

7 changed files with 58 additions and 28 deletions
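For orientation, a minimal sketch of the configuration change this commit makes, seen from a client's `query_config` (the dict shape mirrors the documentation diff below; everything else here is illustrative):

```python
# Before: metadata inclusion was toggled by a boolean flag.
old_query_config = {"include_metadata_in_content": True}

# After: one template string controls how each retrieved chunk is rendered.
# Placeholders: {index} (1-based chunk rank), {chunk.content} (chunk text),
# {metadata} (the chunk's metadata dict).
new_query_config = {
    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
}
```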
docs/_static/llama-stack-spec.html (vendored) — 8 lines changed

@@ -11134,9 +11134,9 @@
           "type": "integer",
           "default": 5
         },
-        "include_metadata_in_content": {
-          "type": "boolean",
-          "default": false
+        "chunk_template": {
+          "type": "string",
+          "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
         }
       },
       "additionalProperties": false,
@@ -11144,7 +11144,7 @@
         "query_generator_config",
         "max_tokens_in_context",
         "max_chunks",
-        "include_metadata_in_content"
+        "chunk_template"
       ],
       "title": "RAGQueryConfig"
     },
docs/_static/llama-stack-spec.yaml (vendored) — 13 lines changed

@@ -7690,15 +7690,20 @@ components:
         max_chunks:
           type: integer
           default: 5
-        include_metadata_in_content:
-          type: boolean
-          default: false
+        chunk_template:
+          type: string
+          default: >
+            Result {index}
+
+            Content: {chunk.content}
+
+            Metadata: {metadata}
       additionalProperties: false
       required:
         - query_generator_config
         - max_tokens_in_context
         - max_chunks
-        - include_metadata_in_content
+        - chunk_template
       title: RAGQueryConfig
     RAGQueryGeneratorConfig:
       oneOf:
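As a quick sanity check on the new default, here is a sketch of how a template containing `{chunk.content}` renders through plain `str.format` attribute lookup (the `Chunk` class below is a stand-in for illustration, not the real llama-stack type):

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    content: str


template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
print(template.format(index=1, chunk=Chunk(content="LoRA adapts low-rank matrices..."), metadata={"source": "docs"}))
# Result 1
# Content: LoRA adapts low-rank matrices...
# Metadata: {'source': 'docs'}
```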
@@ -99,14 +99,14 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```

-You can configure adding metadata to the context if you find it useful for your application. Simply add:
+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
 ```python
 # Query documents
 results = client.tool_runtime.rag_tool.query(
     vector_db_ids=[vector_db_id],
     content="What do you know about...",
     query_config={
-        "include_metadata_in_content": True,
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
     },
 )
 ```
@@ -131,7 +131,7 @@ agent = Agent(
         "query_config": {
             "chunk_size_in_tokens": 512,
             "chunk_overlap_in_tokens": 0,
-            "include_metadata_in_content": False,
+            "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
         },
     },
 }
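Since the template now controls the entire per-chunk rendering, dropping metadata from the context is just a matter of leaving `{metadata}` out of the string. A hedged sketch, reusing the `client` and `vector_db_id` names from the documented example above:

```python
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What do you know about...",
    query_config={
        # {index} and {chunk.content} stay, because the validator below requires them.
        "chunk_template": "Result {index}\nContent: {chunk.content}\n",
    },
)
```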
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Annotated, Any, Literal

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -72,7 +72,19 @@ class RAGQueryConfig(BaseModel):
     query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
-    include_metadata_in_content: bool = False
+    # Optional template for formatting each retrieved chunk in the context.
+    # Available placeholders: {index} (1-based chunk ordinal), {metadata} (chunk metadata dict), {chunk.content} (chunk content string).
+    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+    @field_validator("chunk_template")
+    def validate_chunk_template(cls, v: str) -> str:
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_template must contain {chunk.content}")
+        if "{index}" not in v:
+            raise ValueError("chunk_template must contain {index}")
+        if len(v) == 0:
+            raise ValueError("chunk_template must not be empty")
+        return v


 @runtime_checkable
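A small sketch of what the new validator enforces when a config is constructed directly (the import path is assumed; pydantic surfaces the validator's message inside a ValidationError, which subclasses ValueError):

```python
from llama_stack.apis.tools.rag_tool import RAGQueryConfig  # import path assumed

# Passes: both required placeholders are present.
RAGQueryConfig(chunk_template="Result {index}\nContent: {chunk.content}\n")

# Fails: no {chunk.content} placeholder.
try:
    RAGQueryConfig(chunk_template="just plain text")
except ValueError as err:
    print(err)  # ... chunk_template must contain {chunk.content} ...
```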
@@ -146,8 +146,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata
             tokens += metadata["token_count"]
-            if query_config.include_metadata_in_content:
-                tokens += metadata["metadata_token_count"]
+            tokens += metadata["metadata_token_count"]

             if tokens > query_config.max_tokens_in_context:
                 log.error(
@@ -155,15 +154,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 )
                 break

-            text_content = f"Result {i + 1}:\n"
-            if query_config.include_metadata_in_content:
-                metadata_subset = {
-                    k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]
-                }
-                text_content += f"\nMetadata: {metadata_subset}"
-            else:
-                text_content += f"Document_id:{metadata['document_id'][:5]}"
-            text_content += f"\nContent: {chunk.content}\n"
+            # text_content = f"Result {i + 1}:\n"
+            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
+            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
             picked.append(TextContentItem(text=text_content))

         picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
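Two behavioral consequences of this hunk, illustrated with stand-in values (a sketch, not the provider code itself): metadata token counts now always reduce the context budget, and the bookkeeping keys are stripped before the metadata dict reaches the template:

```python
metadata = {
    "document_id": "num-0",
    "source": "lora_finetune.html",
    "token_count": 120,
    "metadata_token_count": 8,
}

tokens = 0
tokens += metadata["token_count"]
tokens += metadata["metadata_token_count"]  # no longer gated on include_metadata_in_content

metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
print(tokens)           # 128
print(metadata_subset)  # {'document_id': 'num-0', 'source': 'lora_finetune.html'}
```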
@@ -320,3 +320,6 @@ ignore_missing_imports = true
 init_forbid_extra = true
 init_typed = true
 warn_required_dynamic_aliases = true
+
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = ["classmethod", "pydantic.field_validator"]
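For context, a sketch of the pattern this ruff setting accommodates (assuming the pep8-naming N805 check, which expects a method's first argument to be `self`): listing `pydantic.field_validator` as a classmethod decorator lets validators like the one above use `cls` without an explicit `@classmethod`:

```python
from pydantic import BaseModel, field_validator


class Example(BaseModel):
    chunk_template: str

    @field_validator("chunk_template")
    def not_empty(cls, v: str) -> str:  # `cls` passes N805 once ruff knows the decorator
        if not v:
            raise ValueError("chunk_template must not be empty")
        return v
```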
@@ -141,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
-            metadata={"author": "llama", "source": url},
+            metadata={},
         )
         for i, url in enumerate(urls)
     ]
@@ -205,12 +205,29 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
         chunk_size_in_tokens=512,
     )

-    response = client_with_empty_registry.tool_runtime.rag_tool.query(
+    response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+        vector_db_ids=[vector_db_id],
+        content="What is the name of the method used for fine-tuning?",
+    )
+    assert_valid_text_response(response_with_metadata)
+    assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
+
+    response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
         vector_db_ids=[vector_db_id],
         content="What is the name of the method used for fine-tuning?",
         query_config={
             "include_metadata_in_content": True,
+            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
         },
     )
-    assert_valid_text_response(response)
-    assert any("metadata:" in chunk.text.lower() for chunk in response.content)
+    assert_valid_text_response(response_without_metadata)
+    assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
+
+    with pytest.raises(ValueError):
+        client_with_empty_registry.tool_runtime.rag_tool.query(
+            vector_db_ids=[vector_db_id],
+            content="What is the name of the method used for fine-tuning?",
+            query_config={
+                "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
+            },
+        )