From 1c23aeb9372fa3e1286a7d6d8210994000efae6d Mon Sep 17 00:00:00 2001 From: Cesare Pompeiano Date: Wed, 10 Sep 2025 11:19:21 +0200 Subject: [PATCH] feat: Add vector_db_id to chunk metadata (#3304) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? When running RAG in a multi vector DB setting, it can be difficult to trace where retrieved chunks originate from. This PR adds the `vector_db_id` into each chunk’s metadata, making it easier to understand which database a given chunk came from. This is helpful for debugging and for analyzing retrieval behavior of multiple DBs. Relevant code: ```python for vector_db_id, result in zip(vector_db_ids, results): for chunk, score in zip(result.chunks, result.scores): if not hasattr(chunk, "metadata") or chunk.metadata is None: chunk.metadata = {} chunk.metadata["vector_db_id"] = vector_db_id chunks.append(chunk) scores.append(score) ``` ## Test Plan * Ran Llama Stack in debug mode. * Verified that `vector_db_id` was added to each chunk’s metadata. * Confirmed that the metadata was printed in the console when using the RAG tool. --------- Co-authored-by: are-ces Co-authored-by: Francisco Arceo --- .../inline/tool_runtime/rag/memory.py | 16 +++++- tests/unit/rag/test_rag_query.py | 55 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index cb526e8ee..aa629cca8 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -167,8 +167,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti for vector_db_id in vector_db_ids ] results: list[QueryChunksResponse] = await asyncio.gather(*tasks) - chunks = [c for r in results for c in r.chunks] - scores = [s for r in results for s in r.scores] + + chunks = [] + scores = [] + + for vector_db_id, result in zip(vector_db_ids, results, strict=False): + for chunk, score in zip(result.chunks, result.scores, strict=False): + if not hasattr(chunk, "metadata") or chunk.metadata is None: + chunk.metadata = {} + chunk.metadata["vector_db_id"] = vector_db_id + + chunks.append(chunk) + scores.append(score) if not chunks: return RAGQueryResult(content=None) @@ -203,6 +213,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti metadata_keys_to_exclude_from_context = [ "token_count", "metadata_token_count", + "vector_db_id", ] metadata_for_context = {} for k in chunk_metadata_keys_to_include_from_context: @@ -227,6 +238,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], "chunks": [c.content for c in chunks[: len(picked)]], "scores": scores[: len(picked)], + "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]], }, ) diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index d18d90716..7b897bfe0 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -81,3 +81,58 @@ class TestRagQuery: # Test that invalid mode raises an error with pytest.raises(ValueError): RAGQueryConfig(mode="wrong_mode") + + @pytest.mark.asyncio + async def test_query_adds_vector_db_id_to_chunk_metadata(self): + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), + vector_io_api=MagicMock(), + inference_api=MagicMock(), + ) + + vector_db_ids = ["db1", "db2"] + + # Fake chunks from each DB + chunk_metadata1 = ChunkMetadata( + document_id="doc1", + chunk_id="chunk1", + source="test_source1", + metadata_token_count=5, + ) + chunk1 = Chunk( + content="chunk from db1", + metadata={"vector_db_id": "db1", "document_id": "doc1"}, + stored_chunk_id="c1", + chunk_metadata=chunk_metadata1, + ) + + chunk_metadata2 = ChunkMetadata( + document_id="doc2", + chunk_id="chunk2", + source="test_source2", + metadata_token_count=5, + ) + chunk2 = Chunk( + content="chunk from db2", + metadata={"vector_db_id": "db2", "document_id": "doc2"}, + stored_chunk_id="c2", + chunk_metadata=chunk_metadata2, + ) + + rag_tool.vector_io_api.query_chunks = AsyncMock( + side_effect=[ + QueryChunksResponse(chunks=[chunk1], scores=[0.9]), + QueryChunksResponse(chunks=[chunk2], scores=[0.8]), + ] + ) + + result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids) + returned_chunks = result.metadata["chunks"] + returned_scores = result.metadata["scores"] + returned_doc_ids = result.metadata["document_ids"] + returned_vector_db_ids = result.metadata["vector_db_ids"] + + assert returned_chunks == ["chunk from db1", "chunk from db2"] + assert returned_scores == (0.9, 0.8) + assert returned_doc_ids == ["doc1", "doc2"] + assert returned_vector_db_ids == ["db1", "db2"]