diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f1bde880b..5df6db20c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11294,24 +11294,34 @@ "type": "object", "properties": { "query_generator_config": { - "$ref": "#/components/schemas/RAGQueryGeneratorConfig" + "$ref": "#/components/schemas/RAGQueryGeneratorConfig", + "description": "Configuration for the query generator." }, "max_tokens_in_context": { "type": "integer", - "default": 4096 + "default": 4096, + "description": "Maximum number of tokens in the context." }, "max_chunks": { "type": "integer", - "default": 5 + "default": 5, + "description": "Maximum number of chunks to retrieve." + }, + "chunk_template": { + "type": "string", + "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" } }, "additionalProperties": false, "required": [ "query_generator_config", "max_tokens_in_context", - "max_chunks" + "max_chunks", + "chunk_template" ], - "title": "RAGQueryConfig" + "title": "RAGQueryConfig", + "description": "Configuration for the RAG query generation." }, "RAGQueryGeneratorConfig": { "oneOf": [ diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 10b5deec2..fb2dbf241 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7794,18 +7794,37 @@ components: properties: query_generator_config: $ref: '#/components/schemas/RAGQueryGeneratorConfig' + description: Configuration for the query generator. max_tokens_in_context: type: integer default: 4096 + description: Maximum number of tokens in the context. max_chunks: type: integer default: 5 + description: Maximum number of chunks to retrieve. + chunk_template: + type: string + default: > + Result {index} + + Content: {chunk.content} + + Metadata: {metadata} + description: >- + Template for formatting each retrieved chunk in the context. Available + placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk + content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: + {chunk.content}\nMetadata: {metadata}\n" additionalProperties: false required: - query_generator_config - max_tokens_in_context - max_chunks + - chunk_template title: RAGQueryConfig + description: >- + Configuration for the RAG query generation. RAGQueryGeneratorConfig: oneOf: - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index db6303209..dbe90a7fc 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -51,6 +51,7 @@ chunks = [ "mime_type": "text/plain", "metadata": { "document_id": "doc1", + "author": "Jane Doe", }, }, ] @@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query( ) ``` +You can configure how the RAG tool adds metadata to the context if you find it useful for your application. 
Simply add:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="What do you know about...",
+    query_config={
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+    },
+)
+```
 ### Building RAG-Enhanced Agents
 
 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
@@ -115,6 +127,10 @@ agent = Agent(
             "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
+                # Customize how each retrieved chunk is formatted in the context
+                "query_config": {
+                    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+                },
             },
         }
     ],
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index fdf199b1a..de3e4c62c 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -7,7 +7,7 @@ from enum import Enum
 from typing import Annotated, Any, Literal
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Protocol, runtime_checkable
 
 from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
 
 @json_schema_type
 class RAGQueryConfig(BaseModel):
+    """
+    Configuration for the RAG query generation.
+
+    :param query_generator_config: Configuration for the query generator.
+    :param max_tokens_in_context: Maximum number of tokens in the context.
+    :param max_chunks: Maximum number of chunks to retrieve.
+    :param chunk_template: Template for formatting each retrieved chunk in the context.
+        Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
+        Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
+    """
+
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
     query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
+    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+    @field_validator("chunk_template")
+    def validate_chunk_template(cls, v: str) -> str:
+        if len(v) == 0:
+            raise ValueError("chunk_template must not be empty")
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_template must contain {chunk.content}")
+        if "{index}" not in v:
+            raise ValueError("chunk_template must contain {index}")
+        return v
 
 
 @runtime_checkable
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 968f93354..39f752297 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     content,
                     chunk_size_in_tokens,
                     chunk_size_in_tokens // 4,
+                    doc.metadata,
                 )
             )
 
@@ -142,19 +143,21 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
             )
         ]
-        for i, c in enumerate(chunks):
-            metadata = c.metadata
+        for i, chunk in enumerate(chunks):
+            metadata = chunk.metadata
             tokens += metadata["token_count"]
+            tokens += metadata["metadata_token_count"]
+
             if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(
-                TextContentItem(
-                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
-                )
-            )
+
+            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
+            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
+            picked.append(TextContentItem(text=text_content))
+
         picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
         picked.append(
             TextContentItem(
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 9d892c166..e0e9d0679 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -139,22 +139,32 @@ async def content_from_doc(doc: RAGDocument) -> str:
     return interleaved_content_as_str(doc.content)
 
 
-def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
+def make_overlapped_chunks(
+    document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
+) -> list[Chunk]:
     tokenizer = Tokenizer.get_instance()
     tokens = tokenizer.encode(text, bos=False, eos=False)
+    try:
+        metadata_string = str(metadata)
+    except Exception as e:
+        raise ValueError("Failed to serialize metadata to string") from e
+
+    metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)
 
     chunks = []
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        chunk_metadata = metadata.copy()
+        chunk_metadata["document_id"] = document_id
+        chunk_metadata["token_count"] = len(toks)
+        chunk_metadata["metadata_token_count"] = len(metadata_tokens)
+
         # chunk is a string
         chunks.append(
             Chunk(
                 content=chunk,
-                metadata={
-                    "token_count": len(toks),
-                    "document_id": document_id,
-                },
+                metadata=chunk_metadata,
             )
         )
diff --git a/pyproject.toml b/pyproject.toml
index ee180c4c9..f1bf7384f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -320,3 +320,6 @@ ignore_missing_imports = true
 init_forbid_extra = true
 init_typed = true
 warn_required_dynamic_aliases = true
+
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = ["classmethod", "pydantic.field_validator"]
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
index c49f507a8..2d049dc0c 100644
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ b/tests/integration/tool_runtime/test_rag_tool.py
@@ -49,7 +49,7 @@ def sample_documents():
     ]
 
 
-def assert_valid_response(response):
+def assert_valid_chunk_response(response):
     assert len(response.chunks) > 0
     assert len(response.scores) > 0
     assert len(response.chunks) == len(response.scores)
@@ -57,6 +57,11 @@
     assert isinstance(chunk.content, str)
 
 
+def assert_valid_text_response(response):
+    assert len(response.content) > 0
+    assert all(isinstance(chunk.text, str) for chunk in response.content)
+
+
 def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
     vector_db_id = "test_vector_db"
     client_with_empty_registry.vector_dbs.register(
@@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
         vector_db_id=vector_db_id,
         query=query1,
     )
-    assert_valid_response(response1)
+    assert_valid_chunk_response(response1)
     assert any("Python" in chunk.content for chunk in response1.chunks)
 
     # Query with semantic similarity
@@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
         vector_db_id=vector_db_id,
         query=query2,
     )
-    assert_valid_response(response2)
+    assert_valid_chunk_response(response2)
     assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
 
     # Query with limit on number of results (max_chunks=2)
@@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
         query=query3,
         params={"max_chunks": 2},
     )
-    assert_valid_response(response3)
+    assert_valid_chunk_response(response3)
     assert len(response3.chunks) <= 2
 
     # Query with threshold on similarity score
@@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do
         query=query4,
         params={"score_threshold": 0.01},
     )
-    assert_valid_response(response4)
+    assert_valid_chunk_response(response4)
     assert all(score >= 0.01 for score in response4.scores)
 
 
@@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
     available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
     assert vector_db_id in available_vector_dbs
 
-    # URLs of documents to insert
-    # TODO: Move to test/memory/resources then update the url to
-    # https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
     urls = [
         "memory_optimizations.rst",
         "chat.rst",
@@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
         vector_db_id=vector_db_id,
         query="What's the name of the fine-tunning method used?",
     )
-    assert_valid_response(response1)
+    assert_valid_chunk_response(response1)
     assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
 
     # Query for the name of model
@@ -163,5 +165,68 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_
         vector_db_id=vector_db_id,
         query="Which Llama model is mentioned?",
     )
-    assert_valid_response(response2)
+    assert_valid_chunk_response(response2)
     assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
+
+
+def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    assert len(providers) > 0
+
+    vector_db_id = "test_vector_db"
+
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+
+    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
+    assert vector_db_id in available_vector_dbs
+
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+    ]
+    documents = [
+        Document(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={"author": "llama", "source": url},
+        )
+        for i, url in enumerate(urls)
+    ]
+
+    client_with_empty_registry.tool_runtime.rag_tool.insert(
+        documents=documents,
+        vector_db_id=vector_db_id,
+        chunk_size_in_tokens=512,
+    )
+
+    response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+        vector_db_ids=[vector_db_id],
+        content="What is the name of the method used for fine-tuning?",
+    )
+    assert_valid_text_response(response_with_metadata)
+    assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
+
+    response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
+        vector_db_ids=[vector_db_id],
+        content="What is the name of the method used for fine-tuning?",
+        query_config={
+            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
+        },
+    )
+    assert_valid_text_response(response_without_metadata)
+    assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
+
+    with pytest.raises(ValueError):
+        client_with_empty_registry.tool_runtime.rag_tool.query(
+            vector_db_ids=[vector_db_id],
+            content="What is the name of the method used for fine-tuning?",
+            query_config={
+                "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
+            },
+        )
diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py
index 3decc431e..f97808a6d 100644
--- a/tests/unit/rag/test_vector_store.py
+++ b/tests/unit/rag/test_vector_store.py
@@ -12,7 +12,7 @@ from pathlib import Path
 import pytest
 
 from llama_stack.apis.tools import RAGDocument
-from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc
+from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks
 
 DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
 # Depending on the machine, this can get parsed a couple of ways
@@ -76,3 +76,53 @@ class TestVectorStore:
         )
         content = await content_from_doc(doc)
         assert content in DUMMY_PDF_TEXT_CHOICES
+
+    @pytest.mark.parametrize(
+        "window_len, overlap_len, expected_chunks",
+        [
+            (5, 2, 4),  # Create 4 chunks with window of 5 and overlap of 2
+            (4, 1, 4),  # Create 4 chunks with window of 4 and overlap of 1
+        ],
+    )
+    def test_make_overlapped_chunks(self, window_len, overlap_len, expected_chunks):
+        document_id = "test_doc_123"
+        text = "This is a sample document for testing the chunking behavior"
+        original_metadata = {"source": "test", "date": "2023-01-01", "author": "llama"}
+        len_metadata_tokens = 24  # specific to the metadata above
+
+        chunks = make_overlapped_chunks(document_id, text, window_len, overlap_len, original_metadata)
+
+        assert len(chunks) == expected_chunks
+
+        # Check that each chunk has the right metadata
+        for chunk in chunks:
+            # Original metadata should be preserved
+            assert chunk.metadata["source"] == "test"
+            assert chunk.metadata["date"] == "2023-01-01"
+            assert chunk.metadata["author"] == "llama"
+
+            # New metadata should be added
+            assert chunk.metadata["document_id"] == document_id
+            assert "token_count" in chunk.metadata
+            assert isinstance(chunk.metadata["token_count"], int)
+            assert chunk.metadata["token_count"] > 0
+            assert chunk.metadata["metadata_token_count"] == len_metadata_tokens
+
+    def test_raise_overlapped_chunks_metadata_serialization_error(self):
+        document_id = "test_doc_ex"
+        text = "Some text"
+        window_len = 5
+        overlap_len = 2
+
+        class BadMetadata:
+            def __repr__(self):
+                raise TypeError("Cannot convert to string")
+
+        problematic_metadata = {"bad_metadata_example": BadMetadata()}
+
+        with pytest.raises(ValueError) as excinfo:
+            make_overlapped_chunks(document_id, text, window_len, overlap_len, problematic_metadata)
+
+        assert str(excinfo.value) == "Failed to serialize metadata to string"
+        assert isinstance(excinfo.value.__cause__, TypeError)
+        assert str(excinfo.value.__cause__) == "Cannot convert to string"