diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a1543457b..ebdd7ad93 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -131,8 +131,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti for vector_db_id in vector_db_ids ] results: list[QueryChunksResponse] = await asyncio.gather(*tasks) - chunks = [c for r in results for c in r.chunks] - scores = [s for r in results for s in r.scores] + + chunks = [] + scores = [] + + for vector_db_id, result in zip(vector_db_ids, results, strict=False): + for chunk, score in zip(result.chunks, result.scores, strict=False): + if not hasattr(chunk, "metadata") or chunk.metadata is None: + chunk.metadata = {} + chunk.metadata["vector_db_id"] = vector_db_id + + chunks.append(chunk) + scores.append(score) if not chunks: return RAGQueryResult(content=None) diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 05ccecb99..3ce0c5f55 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -54,9 +54,7 @@ class TestRagQuery: result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids) assert result is not None - expected_metadata_string = ( - "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}" - ) + expected_metadata_string = "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1', 'vector_db_id': 'db1'}" assert expected_metadata_string in result.content[1].text assert result.content is not None @@ -77,3 +75,71 @@ class TestRagQuery: # Test that invalid mode raises an error with pytest.raises(ValueError): RAGQueryConfig(mode="wrong_mode") + + @pytest.mark.asyncio + async def test_query_adds_vector_db_id_to_chunk_metadata(self): + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), + vector_io_api=MagicMock(), + inference_api=MagicMock(), + ) + + vector_db_ids = ["db1", "db2"] + + # Fake chunks from each DB + chunk_metadata1 = ChunkMetadata( + document_id="doc1", + chunk_id="chunk1", + source="test_source1", + metadata_token_count=5, + ) + chunk1 = Chunk( + content="chunk from db1", + metadata={"vector_db_id": "db1", "document_id": "doc1"}, + stored_chunk_id="c1", + chunk_metadata=chunk_metadata1, + ) + + chunk_metadata2 = ChunkMetadata( + document_id="doc2", + chunk_id="chunk2", + source="test_source2", + metadata_token_count=5, + ) + chunk2 = Chunk( + content="chunk from db2", + metadata={"vector_db_id": "db2", "document_id": "doc2"}, + stored_chunk_id="c2", + chunk_metadata=chunk_metadata2, + ) + + rag_tool.vector_io_api.query_chunks = AsyncMock( + side_effect=[ + QueryChunksResponse(chunks=[chunk1], scores=[0.9]), + QueryChunksResponse(chunks=[chunk2], scores=[0.8]), + ] + ) + + result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids) + returned_chunks = result.metadata["chunks"] + returned_scores = result.metadata["scores"] + returned_doc_ids = result.metadata["document_ids"] + + assert returned_chunks == ["chunk from db1", "chunk from db2"] + assert returned_scores == (0.9, 0.8) + assert returned_doc_ids == ["doc1", "doc2"] + + # Parse metadata from query result + def parse_metadata(s): + import ast + import re + + match = re.search(r"Metadata:\s*(\{.*\})", s) + if not match: + raise ValueError(f"No metadata found in string: {s}") + return ast.literal_eval(match.group(1)) + + returned_metadata = [ + parse_metadata(item.text)["vector_db_id"] for item in result.content if "Metadata:" in item.text + ] + assert returned_metadata == ["db1", "db2"]