mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
Fixed issues with metadata
This commit is contained in:
parent
030de4bbc2
commit
8a59cb3707
2 changed files with 6 additions and 16 deletions
|
@ -202,6 +202,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||||
"chunks": [c.content for c in chunks[: len(picked)]],
|
"chunks": [c.content for c in chunks[: len(picked)]],
|
||||||
"scores": scores[: len(picked)],
|
"scores": scores[: len(picked)],
|
||||||
|
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,9 @@ class TestRagQuery:
|
||||||
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
|
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
expected_metadata_string = "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1', 'vector_db_id': 'db1'}"
|
expected_metadata_string = (
|
||||||
|
"Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
|
||||||
|
)
|
||||||
assert expected_metadata_string in result.content[1].text
|
assert expected_metadata_string in result.content[1].text
|
||||||
assert result.content is not None
|
assert result.content is not None
|
||||||
|
|
||||||
|
@ -124,22 +126,9 @@ class TestRagQuery:
|
||||||
returned_chunks = result.metadata["chunks"]
|
returned_chunks = result.metadata["chunks"]
|
||||||
returned_scores = result.metadata["scores"]
|
returned_scores = result.metadata["scores"]
|
||||||
returned_doc_ids = result.metadata["document_ids"]
|
returned_doc_ids = result.metadata["document_ids"]
|
||||||
|
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
||||||
|
|
||||||
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||||
assert returned_scores == (0.9, 0.8)
|
assert returned_scores == (0.9, 0.8)
|
||||||
assert returned_doc_ids == ["doc1", "doc2"]
|
assert returned_doc_ids == ["doc1", "doc2"]
|
||||||
|
assert returned_vector_db_ids == ["db1", "db2"]
|
||||||
# Parse metadata from query result
|
|
||||||
def parse_metadata(s):
|
|
||||||
import ast
|
|
||||||
import re
|
|
||||||
|
|
||||||
match = re.search(r"Metadata:\s*(\{.*\})", s)
|
|
||||||
if not match:
|
|
||||||
raise ValueError(f"No metadata found in string: {s}")
|
|
||||||
return ast.literal_eval(match.group(1))
|
|
||||||
|
|
||||||
returned_metadata = [
|
|
||||||
parse_metadata(item.text)["vector_db_id"] for item in result.content if "Metadata:" in item.text
|
|
||||||
]
|
|
||||||
assert returned_metadata == ["db1", "db2"]
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue