From d15368a3026450d1474f4a4db47b89fd3e6057ca Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Thu, 11 Sep 2025 06:20:11 -0600 Subject: [PATCH] chore: Updating documentation, adding exception handling for Vector Stores in RAG Tool, more tests on migration, and migrate off of inference_api for context_retriever for RAG (#3367) # What does this PR do? - Updating documentation on migration from RAG Tool to Vector Stores and Files APIs - Adding exception handling for Vector Stores in RAG Tool - Add more tests on migration from RAG Tool to Vector Stores - Migrate off of inference_api for context_retriever for RAG ## Test Plan Integration and unit tests added Signed-off-by: Francisco Javier Arceo --- docs/source/building_applications/rag.md | 21 ++ .../tool_runtime/rag/context_retriever.py | 12 +- .../inline/tool_runtime/rag/memory.py | 121 ++++++---- .../integration/tool_runtime/test_rag_tool.py | 208 ++++++++++++++++++ .../utils/memory/test_vector_store.py | 38 ++++ 5 files changed, 355 insertions(+), 45 deletions(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 289c38991..802859e87 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,10 +93,31 @@ chunks_response = client.vector_io.query( ### Using the RAG Tool +> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search +> API. We recommend migrating to the OpenAI APIs for better compatibility and future support. + A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the [appendix](#more-ragdocument-examples). +#### OpenAI API Integration & Migration + +The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits: + +- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints +- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies +- **Error Resilience:** When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing. + +**Migration Path:** +We recommend migrating to the OpenAI-compatible Search API for: +1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API +2**Future-Proof**: Continued support and feature development +3**Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API + +The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes. +However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any +documents fail to process, they will be logged in the response but will not cause the entire operation to fail. + ```python from llama_stack_client import RAGDocument diff --git a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py index be18430e4..9bc22f979 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py @@ -8,7 +8,7 @@ from jinja2 import Template from llama_stack.apis.common.content_types import InterleavedContent -from llama_stack.apis.inference import UserMessage +from llama_stack.apis.inference import OpenAIUserMessageParam from llama_stack.apis.tools.rag_tool import ( DefaultRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig, @@ -61,16 +61,16 @@ async def llm_rag_query_generator( messages = [interleaved_content_as_str(content)] template = Template(config.template) - content = template.render({"messages": messages}) + rendered_content: str = template.render({"messages": messages}) model = config.model - message = UserMessage(content=content) - response = await inference_api.chat_completion( - model_id=model, + message = OpenAIUserMessageParam(content=rendered_content) + response = await inference_api.openai_chat_completion( + model=model, messages=[message], stream=False, ) - query = response.completion_message.content + query = response.choices[0].message.content return query diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index aa629cca8..bc68f198d 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str -from llama_stack.providers.utils.memory.vector_store import ( - content_from_doc, - parse_data_url, -) +from llama_stack.providers.utils.memory.vector_store import parse_data_url from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query @@ -60,6 +57,47 @@ def make_random_string(length: int = 8): return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) +async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]: + """Get raw binary data and mime type from a RAGDocument for file upload.""" + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + mime_type = parts["mimetype"] + data = parts["data"] + + if parts["is_base64"]: + file_data = base64.b64decode(data) + else: + file_data = data.encode("utf-8") + + return file_data, mime_type + else: + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + r.raise_for_status() + mime_type = r.headers.get("content-type", "application/octet-stream") + return r.content, mime_type + else: + if isinstance(doc.content, str): + content_str = doc.content + else: + content_str = interleaved_content_as_str(doc.content) + + if content_str.startswith("data:"): + parts = parse_data_url(content_str) + mime_type = parts["mimetype"] + data = parts["data"] + + if parts["is_base64"]: + file_data = base64.b64decode(data) + else: + file_data = data.encode("utf-8") + + return file_data, mime_type + else: + return content_str.encode("utf-8"), "text/plain" + + class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, @@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti return for doc in documents: - if isinstance(doc.content, URL): - if doc.content.uri.startswith("data:"): - parts = parse_data_url(doc.content.uri) - file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() - mime_type = parts["mimetype"] - else: - async with httpx.AsyncClient() as client: - response = await client.get(doc.content.uri) - file_data = response.content - mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream") - else: - content_str = await content_from_doc(doc) - file_data = content_str.encode("utf-8") - mime_type = doc.mime_type or "text/plain" + try: + try: + file_data, mime_type = await raw_data_from_doc(doc) + except Exception as e: + log.error(f"Failed to extract content from document {doc.document_id}: {e}") + continue - file_extension = mimetypes.guess_extension(mime_type) or ".txt" - filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") - file_obj = io.BytesIO(file_data) - file_obj.name = filename + file_obj = io.BytesIO(file_data) + file_obj.name = filename - upload_file = UploadFile(file=file_obj, filename=filename) + upload_file = UploadFile(file=file_obj, filename=filename) - created_file = await self.files_api.openai_upload_file( - file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS - ) + try: + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + except Exception as e: + log.error(f"Failed to upload file for document {doc.document_id}: {e}") + continue - chunking_strategy = VectorStoreChunkingStrategyStatic( - static=VectorStoreChunkingStrategyStaticConfig( - max_chunk_size_tokens=chunk_size_in_tokens, - chunk_overlap_tokens=chunk_size_in_tokens // 4, + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, + ) ) - ) - await self.vector_io_api.openai_attach_file_to_vector_store( - vector_store_id=vector_db_id, - file_id=created_file.id, - attributes=doc.metadata, - chunking_strategy=chunking_strategy, - ) + try: + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, + chunking_strategy=chunking_strategy, + ) + except Exception as e: + log.error( + f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}" + ) + continue + + except Exception as e: + log.error(f"Unexpected error processing document {doc.document_id}: {e}") + continue async def query( self, @@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti if query_config: query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) else: - # handle someone passing an empty dict query_config = RAGQueryConfig() query = kwargs["query"] @@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti ) return ToolInvocationResult( - content=result.content, + content=result.content or [], metadata=result.metadata, ) diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index b208500d8..b78c39af8 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query( assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) +def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension): + vector_db_id = "test_openai_vector_db" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + actual_vector_db_id = available_vector_dbs[0] + + # different document formats that should work with OpenAI APIs + documents = [ + Document( + document_id="text-doc", + content="This is a plain text document about machine learning algorithms.", + metadata={"type": "text", "category": "AI"}, + ), + Document( + document_id="url-doc", + content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst", + mime_type="text/plain", + metadata={"type": "url", "source": "pytorch"}, + ), + Document( + document_id="data-url-doc", + content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning." + metadata={"type": "data_url", "encoding": "base64"}, + ), + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=actual_vector_db_id, + chunk_size_in_tokens=256, + ) + + files_list = client_with_empty_registry.files.list() + assert len(files_list.data) >= len(documents), ( + f"Expected at least {len(documents)} files, got {len(files_list.data)}" + ) + + vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store( + vector_store_id=actual_vector_db_id + ) + assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store" + + response = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[actual_vector_db_id], + content="Tell me about machine learning and deep learning", + ) + + assert_valid_text_response(response) + content_text = " ".join([chunk.text for chunk in response.content]).lower() + assert "machine learning" in content_text or "deep learning" in content_text + + +def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension): + vector_db_id = "test_exception_handling" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + actual_vector_db_id = available_vector_dbs[0] + + documents = [ + Document( + document_id="valid-doc", + content="This is a valid document that should be processed successfully.", + metadata={"status": "valid"}, + ), + Document( + document_id="invalid-url-doc", + content="https://nonexistent-domain-12345.com/invalid.txt", + metadata={"status": "invalid_url"}, + ), + Document( + document_id="another-valid-doc", + content="This is another valid document for testing resilience.", + metadata={"status": "valid"}, + ), + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=actual_vector_db_id, + chunk_size_in_tokens=256, + ) + + response = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[actual_vector_db_id], + content="valid document", + ) + + assert_valid_text_response(response) + content_text = " ".join([chunk.text for chunk in response.content]).lower() + assert "valid document" in content_text + + def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension): providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] assert len(providers) > 0 @@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i "chunk_template": "This should raise a ValueError because it is missing the proper template variables", }, ) + + +def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension): + vector_db_id = "test_query_generation_db" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + actual_vector_db_id = available_vector_dbs[0] + + documents = [ + Document( + document_id="ai-doc", + content="Artificial intelligence and machine learning are transforming technology.", + metadata={"category": "AI"}, + ), + Document( + document_id="banana-doc", + content="Don't bring a banana to a knife fight.", + metadata={"category": "wisdom"}, + ), + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=actual_vector_db_id, + chunk_size_in_tokens=256, + ) + + response = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[actual_vector_db_id], + content="Tell me about AI", + ) + + assert_valid_text_response(response) + content_text = " ".join([chunk.text for chunk in response.content]).lower() + assert "artificial intelligence" in content_text or "machine learning" in content_text + + +def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension): + vector_db_id = "test_pdf_data_url_db" + + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + ) + + available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + actual_vector_db_id = available_vector_dbs[0] + + sample_pdf = b"%PDF-1.3\n3 0 obj\n<>\nendobj\n4 0 obj\n<>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<>\nendobj\n5 0 obj\n<>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n" + + import base64 + + pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8") + pdf_data_url = f"data:application/pdf;base64,{pdf_base64}" + + documents = [ + Document( + document_id="test-pdf-data-url", + content=pdf_data_url, + metadata={"type": "pdf", "source": "data_url"}, + ), + ] + + client_with_empty_registry.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=actual_vector_db_id, + chunk_size_in_tokens=256, + ) + + files_list = client_with_empty_registry.files.list() + assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API" + + pdf_file = None + for file in files_list.data: + if file.filename and "test-pdf-data-url" in file.filename: + pdf_file = file + break + + assert pdf_file is not None, "PDF file should be found in Files API" + assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)" + + file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id) + assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF" + + vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store( + vector_store_id=actual_vector_db_id + ) + assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store" + + response = client_with_empty_registry.tool_runtime.rag_tool.query( + vector_db_ids=[actual_vector_db_id], + content="sample title", + ) + + assert_valid_text_response(response) + content_text = " ".join([chunk.text for chunk in response.content]).lower() + assert "sample title" in content_text or "title" in content_text diff --git a/tests/unit/providers/utils/memory/test_vector_store.py b/tests/unit/providers/utils/memory/test_vector_store.py index 90b229262..590bdd1d2 100644 --- a/tests/unit/providers/utils/memory/test_vector_store.py +++ b/tests/unit/providers/utils/memory/test_vector_store.py @@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail(): # Should raise an exception instead of returning empty string with pytest.raises(UnicodeDecodeError): content_from_data_and_mime_type(data, mime_type) + + +async def test_memory_tool_error_handling(): + """Test that memory tool handles various failures gracefully without crashing.""" + from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig + from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl + + config = RagToolRuntimeConfig() + memory_tool = MemoryToolRuntimeImpl( + config=config, + vector_io_api=AsyncMock(), + inference_api=AsyncMock(), + files_api=AsyncMock(), + ) + + docs = [ + RAGDocument(document_id="good_doc", content="Good content", metadata={}), + RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}), + RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}), + ] + + mock_file1 = MagicMock() + mock_file1.id = "file_good1" + mock_file2 = MagicMock() + mock_file2.id = "file_good2" + memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2] + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.side_effect = Exception("Bad URL") + mock_client.return_value.__aenter__.return_value = mock_instance + + # won't raise exception despite one document failing + await memory_tool.insert(docs, "vector_store_123") + + # processed 2 documents successfully, skipped 1 + assert memory_tool.files_api.openai_upload_file.call_count == 2 + assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2