mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: Add vector_db_id to chunk metadata (#3304)
# What does this PR do? When running RAG in a multi vector DB setting, it can be difficult to trace where retrieved chunks originate from. This PR adds the `vector_db_id` into each chunk’s metadata, making it easier to understand which database a given chunk came from. This is helpful for debugging and for analyzing retrieval behavior of multiple DBs. Relevant code: ```python for vector_db_id, result in zip(vector_db_ids, results): for chunk, score in zip(result.chunks, result.scores): if not hasattr(chunk, "metadata") or chunk.metadata is None: chunk.metadata = {} chunk.metadata["vector_db_id"] = vector_db_id chunks.append(chunk) scores.append(score) ``` ## Test Plan * Ran Llama Stack in debug mode. * Verified that `vector_db_id` was added to each chunk’s metadata. * Confirmed that the metadata was printed in the console when using the RAG tool. --------- Co-authored-by: are-ces <cpompeia@redhat.com> Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
This commit is contained in:
parent
81ad240faa
commit
1c23aeb937
2 changed files with 69 additions and 2 deletions
|
@ -167,8 +167,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
for vector_db_id in vector_db_ids
|
for vector_db_id in vector_db_ids
|
||||||
]
|
]
|
||||||
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
|
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||||
chunks = [c for r in results for c in r.chunks]
|
|
||||||
scores = [s for r in results for s in r.scores]
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
|
||||||
|
for chunk, score in zip(result.chunks, result.scores, strict=False):
|
||||||
|
if not hasattr(chunk, "metadata") or chunk.metadata is None:
|
||||||
|
chunk.metadata = {}
|
||||||
|
chunk.metadata["vector_db_id"] = vector_db_id
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return RAGQueryResult(content=None)
|
return RAGQueryResult(content=None)
|
||||||
|
@ -203,6 +213,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
metadata_keys_to_exclude_from_context = [
|
metadata_keys_to_exclude_from_context = [
|
||||||
"token_count",
|
"token_count",
|
||||||
"metadata_token_count",
|
"metadata_token_count",
|
||||||
|
"vector_db_id",
|
||||||
]
|
]
|
||||||
metadata_for_context = {}
|
metadata_for_context = {}
|
||||||
for k in chunk_metadata_keys_to_include_from_context:
|
for k in chunk_metadata_keys_to_include_from_context:
|
||||||
|
@ -227,6 +238,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||||
"chunks": [c.content for c in chunks[: len(picked)]],
|
"chunks": [c.content for c in chunks[: len(picked)]],
|
||||||
"scores": scores[: len(picked)],
|
"scores": scores[: len(picked)],
|
||||||
|
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -81,3 +81,58 @@ class TestRagQuery:
|
||||||
# Test that invalid mode raises an error
|
# Test that invalid mode raises an error
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
RAGQueryConfig(mode="wrong_mode")
|
RAGQueryConfig(mode="wrong_mode")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
||||||
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=MagicMock(),
|
||||||
|
vector_io_api=MagicMock(),
|
||||||
|
inference_api=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_db_ids = ["db1", "db2"]
|
||||||
|
|
||||||
|
# Fake chunks from each DB
|
||||||
|
chunk_metadata1 = ChunkMetadata(
|
||||||
|
document_id="doc1",
|
||||||
|
chunk_id="chunk1",
|
||||||
|
source="test_source1",
|
||||||
|
metadata_token_count=5,
|
||||||
|
)
|
||||||
|
chunk1 = Chunk(
|
||||||
|
content="chunk from db1",
|
||||||
|
metadata={"vector_db_id": "db1", "document_id": "doc1"},
|
||||||
|
stored_chunk_id="c1",
|
||||||
|
chunk_metadata=chunk_metadata1,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_metadata2 = ChunkMetadata(
|
||||||
|
document_id="doc2",
|
||||||
|
chunk_id="chunk2",
|
||||||
|
source="test_source2",
|
||||||
|
metadata_token_count=5,
|
||||||
|
)
|
||||||
|
chunk2 = Chunk(
|
||||||
|
content="chunk from db2",
|
||||||
|
metadata={"vector_db_id": "db2", "document_id": "doc2"},
|
||||||
|
stored_chunk_id="c2",
|
||||||
|
chunk_metadata=chunk_metadata2,
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_tool.vector_io_api.query_chunks = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
|
||||||
|
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
|
||||||
|
returned_chunks = result.metadata["chunks"]
|
||||||
|
returned_scores = result.metadata["scores"]
|
||||||
|
returned_doc_ids = result.metadata["document_ids"]
|
||||||
|
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
||||||
|
|
||||||
|
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||||
|
assert returned_scores == (0.9, 0.8)
|
||||||
|
assert returned_doc_ids == ["doc1", "doc2"]
|
||||||
|
assert returned_vector_db_ids == ["db1", "db2"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue