mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-24 16:57:21 +00:00 
			
		
		
		
	revert: "chore(cleanup)!: remove tool_runtime.rag_tool" (#3877)
Reverts llamastack/llama-stack#3871 This PR broke RAG (even from Responses -- there _is_ a dependency)
This commit is contained in:
		
							parent
							
								
									eb3e9b85f9
								
							
						
					
					
						commit
						bd3c473208
					
				
					 55 changed files with 3114 additions and 17 deletions
				
			
		
							
								
								
									
										138
									
								
								tests/unit/rag/test_rag_query.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								tests/unit/rag/test_rag_query.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,138 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from unittest.mock import AsyncMock, MagicMock | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from llama_stack.apis.tools.rag_tool import RAGQueryConfig | ||||
| from llama_stack.apis.vector_io import ( | ||||
|     Chunk, | ||||
|     ChunkMetadata, | ||||
|     QueryChunksResponse, | ||||
| ) | ||||
| from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl | ||||
| 
 | ||||
| 
 | ||||
| class TestRagQuery: | ||||
|     async def test_query_raises_on_empty_vector_store_ids(self): | ||||
|         rag_tool = MemoryToolRuntimeImpl( | ||||
|             config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() | ||||
|         ) | ||||
|         with pytest.raises(ValueError): | ||||
|             await rag_tool.query(content=MagicMock(), vector_db_ids=[]) | ||||
| 
 | ||||
|     async def test_query_chunk_metadata_handling(self): | ||||
|         rag_tool = MemoryToolRuntimeImpl( | ||||
|             config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() | ||||
|         ) | ||||
|         content = "test query content" | ||||
|         vector_db_ids = ["db1"] | ||||
| 
 | ||||
|         chunk_metadata = ChunkMetadata( | ||||
|             document_id="doc1", | ||||
|             chunk_id="chunk1", | ||||
|             source="test_source", | ||||
|             metadata_token_count=5, | ||||
|         ) | ||||
|         interleaved_content = MagicMock() | ||||
|         chunk = Chunk( | ||||
|             content=interleaved_content, | ||||
|             metadata={ | ||||
|                 "key1": "value1", | ||||
|                 "token_count": 10, | ||||
|                 "metadata_token_count": 5, | ||||
|                 # Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert() | ||||
|                 "document_id": "doc1", | ||||
|             }, | ||||
|             stored_chunk_id="chunk1", | ||||
|             chunk_metadata=chunk_metadata, | ||||
|         ) | ||||
| 
 | ||||
|         query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0]) | ||||
| 
 | ||||
|         rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response) | ||||
|         result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids) | ||||
| 
 | ||||
|         assert result is not None | ||||
|         expected_metadata_string = ( | ||||
|             "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}" | ||||
|         ) | ||||
|         assert expected_metadata_string in result.content[1].text | ||||
|         assert result.content is not None | ||||
| 
 | ||||
|     async def test_query_raises_incorrect_mode(self): | ||||
|         with pytest.raises(ValueError): | ||||
|             RAGQueryConfig(mode="invalid_mode") | ||||
| 
 | ||||
|     async def test_query_accepts_valid_modes(self): | ||||
|         default_config = RAGQueryConfig()  # Test default (vector) | ||||
|         assert default_config.mode == "vector" | ||||
|         vector_config = RAGQueryConfig(mode="vector")  # Test vector | ||||
|         assert vector_config.mode == "vector" | ||||
|         keyword_config = RAGQueryConfig(mode="keyword")  # Test keyword | ||||
|         assert keyword_config.mode == "keyword" | ||||
|         hybrid_config = RAGQueryConfig(mode="hybrid")  # Test hybrid | ||||
|         assert hybrid_config.mode == "hybrid" | ||||
| 
 | ||||
|         # Test that invalid mode raises an error | ||||
|         with pytest.raises(ValueError): | ||||
|             RAGQueryConfig(mode="wrong_mode") | ||||
| 
 | ||||
|     async def test_query_adds_vector_store_id_to_chunk_metadata(self): | ||||
|         rag_tool = MemoryToolRuntimeImpl( | ||||
|             config=MagicMock(), | ||||
|             vector_io_api=MagicMock(), | ||||
|             inference_api=MagicMock(), | ||||
|             files_api=MagicMock(), | ||||
|         ) | ||||
| 
 | ||||
|         vector_db_ids = ["db1", "db2"] | ||||
| 
 | ||||
|         # Fake chunks from each DB | ||||
|         chunk_metadata1 = ChunkMetadata( | ||||
|             document_id="doc1", | ||||
|             chunk_id="chunk1", | ||||
|             source="test_source1", | ||||
|             metadata_token_count=5, | ||||
|         ) | ||||
|         chunk1 = Chunk( | ||||
|             content="chunk from db1", | ||||
|             metadata={"vector_db_id": "db1", "document_id": "doc1"}, | ||||
|             stored_chunk_id="c1", | ||||
|             chunk_metadata=chunk_metadata1, | ||||
|         ) | ||||
| 
 | ||||
|         chunk_metadata2 = ChunkMetadata( | ||||
|             document_id="doc2", | ||||
|             chunk_id="chunk2", | ||||
|             source="test_source2", | ||||
|             metadata_token_count=5, | ||||
|         ) | ||||
|         chunk2 = Chunk( | ||||
|             content="chunk from db2", | ||||
|             metadata={"vector_db_id": "db2", "document_id": "doc2"}, | ||||
|             stored_chunk_id="c2", | ||||
|             chunk_metadata=chunk_metadata2, | ||||
|         ) | ||||
| 
 | ||||
|         rag_tool.vector_io_api.query_chunks = AsyncMock( | ||||
|             side_effect=[ | ||||
|                 QueryChunksResponse(chunks=[chunk1], scores=[0.9]), | ||||
|                 QueryChunksResponse(chunks=[chunk2], scores=[0.8]), | ||||
|             ] | ||||
|         ) | ||||
| 
 | ||||
|         result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids) | ||||
|         returned_chunks = result.metadata["chunks"] | ||||
|         returned_scores = result.metadata["scores"] | ||||
|         returned_doc_ids = result.metadata["document_ids"] | ||||
|         returned_vector_db_ids = result.metadata["vector_db_ids"] | ||||
| 
 | ||||
|         assert returned_chunks == ["chunk from db1", "chunk from db2"] | ||||
|         assert returned_scores == (0.9, 0.8) | ||||
|         assert returned_doc_ids == ["doc1", "doc2"] | ||||
|         assert returned_vector_db_ids == ["db1", "db2"] | ||||
|  | @ -4,6 +4,10 @@ | |||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import base64 | ||||
| import mimetypes | ||||
| import os | ||||
| from pathlib import Path | ||||
| from unittest.mock import AsyncMock, MagicMock | ||||
| 
 | ||||
| import numpy as np | ||||
|  | @ -13,13 +17,37 @@ from llama_stack.apis.inference.inference import ( | |||
|     OpenAIEmbeddingData, | ||||
|     OpenAIEmbeddingsRequestWithExtraBody, | ||||
| ) | ||||
| from llama_stack.apis.tools import RAGDocument | ||||
| from llama_stack.apis.vector_io import Chunk | ||||
| from llama_stack.providers.utils.memory.vector_store import ( | ||||
|     URL, | ||||
|     VectorStoreWithIndex, | ||||
|     _validate_embedding, | ||||
|     content_from_doc, | ||||
|     make_overlapped_chunks, | ||||
| ) | ||||
| 
 | ||||
| DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" | ||||
| # Depending on the machine, this can get parsed a couple of ways | ||||
| DUMMY_PDF_TEXT_CHOICES = ["Dummy PDF file", "Dumm y PDF file"] | ||||
| 
 | ||||
| 
 | ||||
| def read_file(file_path: str) -> bytes: | ||||
|     with open(file_path, "rb") as file: | ||||
|         return file.read() | ||||
| 
 | ||||
| 
 | ||||
| def data_url_from_file(file_path: str) -> str: | ||||
|     with open(file_path, "rb") as file: | ||||
|         file_content = file.read() | ||||
| 
 | ||||
|     base64_content = base64.b64encode(file_content).decode("utf-8") | ||||
|     mime_type, _ = mimetypes.guess_type(file_path) | ||||
| 
 | ||||
|     data_url = f"data:{mime_type};base64,{base64_content}" | ||||
| 
 | ||||
|     return data_url | ||||
| 
 | ||||
| 
 | ||||
| class TestChunk: | ||||
|     def test_chunk(self): | ||||
|  | @ -88,6 +116,45 @@ class TestValidateEmbedding: | |||
| 
 | ||||
| 
 | ||||
| class TestVectorStore: | ||||
|     async def test_returns_content_from_pdf_data_uri(self): | ||||
|         data_uri = data_url_from_file(DUMMY_PDF_PATH) | ||||
|         doc = RAGDocument( | ||||
|             document_id="dummy", | ||||
|             content=data_uri, | ||||
|             mime_type="application/pdf", | ||||
|             metadata={}, | ||||
|         ) | ||||
|         content = await content_from_doc(doc) | ||||
|         assert content in DUMMY_PDF_TEXT_CHOICES | ||||
| 
 | ||||
|     @pytest.mark.allow_network | ||||
|     async def test_downloads_pdf_and_returns_content(self): | ||||
|         # Using GitHub to host the PDF file | ||||
|         url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" | ||||
|         doc = RAGDocument( | ||||
|             document_id="dummy", | ||||
|             content=url, | ||||
|             mime_type="application/pdf", | ||||
|             metadata={}, | ||||
|         ) | ||||
|         content = await content_from_doc(doc) | ||||
|         assert content in DUMMY_PDF_TEXT_CHOICES | ||||
| 
 | ||||
|     @pytest.mark.allow_network | ||||
|     async def test_downloads_pdf_and_returns_content_with_url_object(self): | ||||
|         # Using GitHub to host the PDF file | ||||
|         url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" | ||||
|         doc = RAGDocument( | ||||
|             document_id="dummy", | ||||
|             content=URL( | ||||
|                 uri=url, | ||||
|             ), | ||||
|             mime_type="application/pdf", | ||||
|             metadata={}, | ||||
|         ) | ||||
|         content = await content_from_doc(doc) | ||||
|         assert content in DUMMY_PDF_TEXT_CHOICES | ||||
| 
 | ||||
|     @pytest.mark.parametrize( | ||||
|         "window_len, overlap_len, expected_chunks", | ||||
|         [ | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue