From e50a546bc0315aa6780f622773b74dfe8bf9da11 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Fri, 9 May 2025 23:38:47 -0400 Subject: [PATCH] feat: Adding support for metadata in RAG insertion and querying Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 7 +- docs/_static/llama-stack-spec.yaml | 4 ++ docs/source/building_applications/rag.md | 18 +++++ llama_stack/apis/tools/rag_tool.py | 1 + .../inline/tool_runtime/rag/memory.py | 25 +++++-- .../providers/utils/memory/vector_store.py | 15 ++-- .../integration/tool_runtime/test_rag_tool.py | 71 ++++++++++++++++--- tests/unit/rag/test_vector_store.py | 33 ++++++++- 8 files changed, 149 insertions(+), 25 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 4020dc4cd..3e673aeed 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11133,13 +11133,18 @@ "max_chunks": { "type": "integer", "default": 5 + }, + "include_metadata_in_content": { + "type": "boolean", + "default": false } }, "additionalProperties": false, "required": [ "query_generator_config", "max_tokens_in_context", - "max_chunks" + "max_chunks", + "include_metadata_in_content" ], "title": "RAGQueryConfig" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 62e3ca85c..5fdabfb2e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7690,11 +7690,15 @@ components: max_chunks: type: integer default: 5 + include_metadata_in_content: + type: boolean + default: false additionalProperties: false required: - query_generator_config - max_tokens_in_context - max_chunks + - include_metadata_in_content title: RAGQueryConfig RAGQueryGeneratorConfig: oneOf: diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index db6303209..663d25f72 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -51,6 +51,7 @@ chunks = [ "mime_type": "text/plain", "metadata": { "document_id": "doc1", + "author": "Jane Doe", }, }, ] @@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query( ) ``` +You can configure adding metadata to the context if you find it useful for your application. Simply add: +```python +# Query documents +results = client.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What do you know about...", + query_config={ + "include_metadata_in_content": True, + }, +) +``` ### Building RAG-Enhanced Agents One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: @@ -115,6 +127,12 @@ agent = Agent( "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], + # Defaults + "query_config": { + "chunk_size_in_tokens": 512, + "chunk_overlap_in_tokens": 0, + "include_metadata_in_content": False, + }, }, } ], diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index fdf199b1a..8dc7b7385 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -72,6 +72,7 @@ class RAGQueryConfig(BaseModel): query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig()) max_tokens_in_context: int = 4096 max_chunks: int = 5 + include_metadata_in_content: bool = False @runtime_checkable diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index df0257718..c290e89ec 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -87,6 +87,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): content, chunk_size_in_tokens, chunk_size_in_tokens // 4, + doc.metadata, ) ) @@ -140,19 +141,29 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" ) ] - for i, c in enumerate(chunks): - metadata = c.metadata + for i, chunk in enumerate(chunks): + metadata = chunk.metadata tokens += metadata["token_count"] + if query_config.include_metadata_in_content: + tokens += metadata["metadata_token_count"] + if tokens > query_config.max_tokens_in_context: log.error( f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", ) break - picked.append( - TextContentItem( - text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n", - ) - ) + + text_content = f"Result {i + 1}:\n" + if query_config.include_metadata_in_content: + metadata_subset = { + k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"] + } + text_content += f"\nMetadata: {metadata_subset}" + else: + text_content += f"Document_id:{metadata['document_id'][:5]}" + text_content += f"\nContent: {chunk.content}\n" + picked.append(TextContentItem(text=text_content)) + picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) picked.append( TextContentItem( diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 9d892c166..1bc00d1d3 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -139,22 +139,27 @@ async def content_from_doc(doc: RAGDocument) -> str: return interleaved_content_as_str(doc.content) -def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]: +def make_overlapped_chunks( + document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any] +) -> list[Chunk]: tokenizer = Tokenizer.get_instance() tokens = tokenizer.encode(text, bos=False, eos=False) + metadata_tokens = tokenizer.encode(str(metadata), bos=False, eos=False) chunks = [] for i in range(0, len(tokens), window_len - overlap_len): toks = tokens[i : i + window_len] chunk = tokenizer.decode(toks) + chunk_metadata = metadata.copy() + chunk_metadata["document_id"] = document_id + chunk_metadata["token_count"] = len(toks) + chunk_metadata["metadata_token_count"] = len(metadata_tokens) + # chunk is a string chunks.append( Chunk( content=chunk, - metadata={ - "token_count": len(toks), - "document_id": document_id, - }, + metadata=chunk_metadata, ) ) diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index c49f507a8..3f13fa101 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -49,7 +49,7 @@ def sample_documents(): ] -def assert_valid_response(response): +def assert_valid_chunk_response(response): assert len(response.chunks) > 0 assert len(response.scores) > 0 assert len(response.chunks) == len(response.scores) @@ -57,6 +57,11 @@ def assert_valid_response(response): assert isinstance(chunk.content, str) +def assert_valid_text_response(response): + assert len(response.content) > 0 + assert all(isinstance(chunk.text, str) for chunk in response.content) + + def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id): vector_db_id = "test_vector_db" client_with_empty_registry.vector_dbs.register( @@ -77,7 +82,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do vector_db_id=vector_db_id, query=query1, ) - assert_valid_response(response1) + assert_valid_chunk_response(response1) assert any("Python" in chunk.content for chunk in response1.chunks) # Query with semantic similarity @@ -86,7 +91,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do vector_db_id=vector_db_id, query=query2, ) - assert_valid_response(response2) + assert_valid_chunk_response(response2) assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks) # Query with limit on number of results (max_chunks=2) @@ -96,7 +101,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do query=query3, params={"max_chunks": 2}, ) - assert_valid_response(response3) + assert_valid_chunk_response(response3) assert len(response3.chunks) <= 2 # Query with threshold on similarity score @@ -106,7 +111,7 @@ def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_do query=query4, params={"score_threshold": 0.01}, ) - assert_valid_response(response4) + assert_valid_chunk_response(response4) assert all(score >= 0.01 for score in response4.scores) @@ -126,9 +131,6 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] assert vector_db_id in available_vector_dbs - # URLs of documents to insert - # TODO: Move to test/memory/resources then update the url to - # https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url} urls = [ "memory_optimizations.rst", "chat.rst", @@ -139,7 +141,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", - metadata={}, + metadata={"author": "llama", "source": url}, ) for i, url in enumerate(urls) ] @@ -155,7 +157,7 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ vector_db_id=vector_db_id, query="What's the name of the fine-tunning method used?", ) - assert_valid_response(response1) + assert_valid_chunk_response(response1) assert any("lora" in chunk.content.lower() for chunk in response1.chunks) # Query for the name of model @@ -163,5 +165,52 @@ def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_ vector_db_id=vector_db_id, query="Which Llama model is mentioned?", ) - assert_valid_response(response2) + assert_valid_chunk_response(response2) assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) + + +def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + assert len(providers) > 0 + + vector_db_id = "test_vector_db" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=384, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + assert vector_db_id in available_vector_dbs + + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + ] + documents = [ + Document( + document_id=f"num-{i}", + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + metadata={"author": "llama", "source": url}, + ) + for i, url in enumerate(urls) + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=vector_db_id, + chunk_size_in_tokens=512, + ) + + response = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="What is the name of the method used for fine-tuning?", + query_config={ + "include_metadata_in_content": True, + }, + ) + assert_valid_text_response(response) + assert any("metadata:" in chunk.text.lower() for chunk in response.content) diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py index 3decc431e..2624c058c 100644 --- a/tests/unit/rag/test_vector_store.py +++ b/tests/unit/rag/test_vector_store.py @@ -12,7 +12,7 @@ from pathlib import Path import pytest from llama_stack.apis.tools import RAGDocument -from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc +from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc, make_overlapped_chunks DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" # Depending on the machine, this can get parsed a couple of ways @@ -76,3 +76,34 @@ class TestVectorStore: ) content = await content_from_doc(doc) assert content in DUMMY_PDF_TEXT_CHOICES + + @pytest.mark.parametrize( + "window_len, overlap_len, expected_chunks", + [ + (5, 2, 4), # Create 4 chunks with window of 5 and overlap of 2 + (4, 1, 4), # Create 4 chunks with window of 3 and overlap of 1 + ], + ) + def test_make_overlapped_chunks(self, window_len, overlap_len, expected_chunks): + document_id = "test_doc_123" + text = "This is a sample document for testing the chunking behavior" + original_metadata = {"source": "test", "date": "2023-01-01", "author": "llama"} + len_metadata_tokens = 24 # specific to the metadata above + + chunks = make_overlapped_chunks(document_id, text, window_len, overlap_len, original_metadata) + + assert len(chunks) == expected_chunks + + # Check that each chunk has the right metadata + for chunk in chunks: + # Original metadata should be preserved + assert chunk.metadata["source"] == "test" + assert chunk.metadata["date"] == "2023-01-01" + assert chunk.metadata["author"] == "llama" + + # New metadata should be added + assert chunk.metadata["document_id"] == document_id + assert "token_count" in chunk.metadata + assert isinstance(chunk.metadata["token_count"], int) + assert chunk.metadata["token_count"] > 0 + assert chunk.metadata["metadata_token_count"] == len_metadata_tokens