diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index c0d80172e..6a7c7885c 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -149,9 +149,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti ] for i, chunk in enumerate(chunks): metadata = chunk.metadata - # update chunk.metadata with the chunk.chunk_metadata if it exists - if chunk.chunk_metadata: - metadata = {**metadata, **chunk.chunk_metadata.dict()} tokens += metadata.get("token_count", 0) tokens += metadata.get("metadata_token_count", 0) @@ -161,21 +158,24 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti ) break - metadata_fields_to_exclude_from_context = [ - "created_timestamp", - "updated_timestamp", - "chunk_window", - "chunk_tokenizer", - "chunk_embedding_model", - "chunk_embedding_dimension", + # Add useful keys from chunk_metadata to metadata and remove some from metadata + chunk_metadata_keys_to_include_from_context = [ + "chunk_id", + "document_id", + "source", + ] + metadata_keys_to_exclude_from_context = [ "token_count", - "content_token_count", "metadata_token_count", ] - metadata_subset = { - k: v for k, v in metadata.items() if k not in metadata_fields_to_exclude_from_context and v - } - text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset) + metadata_for_context = {} + for k in chunk_metadata_keys_to_include_from_context: + metadata_for_context[k] = getattr(chunk.chunk_metadata, k) + for k in metadata: + if k not in metadata_keys_to_exclude_from_context: + metadata_for_context[k] = metadata[k] + + text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context) picked.append(TextContentItem(text=text_content)) picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 835bec90a..3b3c5f486 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -399,13 +399,7 @@ class SQLiteVecIndex(EmbeddingIndex): filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold] # Create a map of chunk_id to chunk for both responses - chunk_map = {} - for c in vector_response.chunks: - chunk_id = c.chunk_id - chunk_map[chunk_id] = c - for c in keyword_response.chunks: - chunk_id = c.chunk_id - chunk_map[chunk_id] = c + chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks} # Use the map to look up chunks by their IDs chunks = [] diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 9a24cff1b..d2dd1783b 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -56,7 +56,7 @@ class TestRagQuery: assert result is not None expected_metadata_string = ( - "Metadata: {'key1': 'value1', 'document_id': 'doc1', 'chunk_id': 'chunk1', 'source': 'test_source'}" + "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}" ) assert expected_metadata_string in result.content[1].text assert result.content is not None