mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	# What does this PR do?
	<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
	Fix pre-commit issues: a non-executable file with a shebang, and the `@pytest.mark.asyncio` decorator.
	<!-- If resolving an issue, uncomment and update the line below -->
	<!-- Closes #[issue-number] -->
	## Test Plan
	<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
		
			
				
	
	
		
			138 lines
		
	
	
	
		
			5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			138 lines
		
	
	
	
		
			5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from unittest.mock import AsyncMock, MagicMock
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from llama_stack.apis.tools.rag_tool import RAGQueryConfig
 | |
| from llama_stack.apis.vector_io import (
 | |
|     Chunk,
 | |
|     ChunkMetadata,
 | |
|     QueryChunksResponse,
 | |
| )
 | |
| from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
 | |
| 
 | |
| 
 | |
class TestRagQuery:
    """Tests for ``MemoryToolRuntimeImpl.query`` and ``RAGQueryConfig`` mode validation.

    Every dependency handed to ``MemoryToolRuntimeImpl`` is a mock, and
    ``vector_io_api.query_chunks`` is stubbed with canned ``QueryChunksResponse``
    objects where a query is expected to succeed.
    """

    @staticmethod
    def _make_rag_tool() -> "MemoryToolRuntimeImpl":
        """Return a ``MemoryToolRuntimeImpl`` whose config and APIs are all ``MagicMock`` objects."""
        return MemoryToolRuntimeImpl(
            config=MagicMock(),
            vector_io_api=MagicMock(),
            inference_api=MagicMock(),
            files_api=MagicMock(),
        )

    async def test_query_raises_on_empty_vector_db_ids(self):
        """``query`` is expected to raise ``ValueError`` when ``vector_db_ids`` is empty."""
        rag_tool = self._make_rag_tool()
        with pytest.raises(ValueError):
            await rag_tool.query(content=MagicMock(), vector_db_ids=[])

    async def test_query_chunk_metadata_handling(self):
        """The rendered result is expected to contain the chunk's metadata string.

        The expected string lists ``chunk_id``, ``document_id``, ``source`` and the
        custom ``key1`` entry, and does not include the token-count keys that are
        present in the chunk's ``metadata`` dict.
        """
        rag_tool = self._make_rag_tool()
        content = "test query content"
        vector_db_ids = ["db1"]

        chunk_metadata = ChunkMetadata(
            document_id="doc1",
            chunk_id="chunk1",
            source="test_source",
            metadata_token_count=5,
        )
        interleaved_content = MagicMock()
        chunk = Chunk(
            content=interleaved_content,
            metadata={
                "key1": "value1",
                "token_count": 10,
                "metadata_token_count": 5,
                # Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
                "document_id": "doc1",
            },
            stored_chunk_id="chunk1",
            chunk_metadata=chunk_metadata,
        )

        query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])

        rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
        result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)

        assert result is not None
        # Guard before indexing so a missing `content` fails on this assertion
        # instead of surfacing as a TypeError on the subscript below.
        assert result.content is not None
        expected_metadata_string = (
            "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
        )
        assert expected_metadata_string in result.content[1].text

    async def test_query_raises_incorrect_mode(self):
        """Constructing ``RAGQueryConfig`` with an unknown mode is expected to raise ``ValueError``."""
        with pytest.raises(ValueError):
            RAGQueryConfig(mode="invalid_mode")

    async def test_query_accepts_valid_modes(self):
        """``RAGQueryConfig`` defaults to ``"vector"`` and accepts ``vector``, ``keyword`` and ``hybrid``."""
        default_config = RAGQueryConfig()  # Test default (vector)
        assert default_config.mode == "vector"
        vector_config = RAGQueryConfig(mode="vector")  # Test vector
        assert vector_config.mode == "vector"
        keyword_config = RAGQueryConfig(mode="keyword")  # Test keyword
        assert keyword_config.mode == "keyword"
        hybrid_config = RAGQueryConfig(mode="hybrid")  # Test hybrid
        assert hybrid_config.mode == "hybrid"

        # Test that invalid mode raises an error
        with pytest.raises(ValueError):
            RAGQueryConfig(mode="wrong_mode")

    async def test_query_adds_vector_db_id_to_chunk_metadata(self):
        """Querying two vector DBs is expected to report per-chunk data in query order.

        ``query_chunks`` is stubbed to return one chunk per call, and the result
        metadata is checked for the chunk contents, scores, document ids and
        vector DB ids of both chunks.
        """
        rag_tool = self._make_rag_tool()

        vector_db_ids = ["db1", "db2"]

        # Fake chunks from each DB
        chunk_metadata1 = ChunkMetadata(
            document_id="doc1",
            chunk_id="chunk1",
            source="test_source1",
            metadata_token_count=5,
        )
        chunk1 = Chunk(
            content="chunk from db1",
            metadata={"vector_db_id": "db1", "document_id": "doc1"},
            stored_chunk_id="c1",
            chunk_metadata=chunk_metadata1,
        )

        chunk_metadata2 = ChunkMetadata(
            document_id="doc2",
            chunk_id="chunk2",
            source="test_source2",
            metadata_token_count=5,
        )
        chunk2 = Chunk(
            content="chunk from db2",
            metadata={"vector_db_id": "db2", "document_id": "doc2"},
            stored_chunk_id="c2",
            chunk_metadata=chunk_metadata2,
        )

        # `side_effect` with a list hands out one response per successive call.
        rag_tool.vector_io_api.query_chunks = AsyncMock(
            side_effect=[
                QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
                QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
            ]
        )

        result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
        returned_chunks = result.metadata["chunks"]
        returned_scores = result.metadata["scores"]
        returned_doc_ids = result.metadata["document_ids"]
        returned_vector_db_ids = result.metadata["vector_db_ids"]

        assert returned_chunks == ["chunk from db1", "chunk from db2"]
        assert returned_scores == (0.9, 0.8)
        assert returned_doc_ids == ["doc1", "doc2"]
        assert returned_vector_db_ids == ["db1", "db2"]
 |