From 7cd1c2c238969cb41ddb3d87b7dcfe5331350be1 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sat, 6 Sep 2025 07:26:34 -0600 Subject: [PATCH] feat: Updating Rag Tool to use Files API and Vector Stores API (#3344) --- docs/source/getting_started/demo_script.py | 5 +- .../inline/tool_runtime/rag/__init__.py | 2 +- .../inline/tool_runtime/rag/memory.py | 74 ++++++++++++++----- .../providers/registry/tool_runtime.py | 2 +- .../integration/tool_runtime/test_rag_tool.py | 41 ++++++---- tests/unit/rag/test_rag_query.py | 8 +- 6 files changed, 93 insertions(+), 39 deletions(-) diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py index 777fc78c2..2ea67739f 100644 --- a/docs/source/getting_started/demo_script.py +++ b/docs/source/getting_started/demo_script.py @@ -18,12 +18,13 @@ embedding_model_id = ( ).identifier embedding_dimension = em.metadata["embedding_dimension"] -_ = client.vector_dbs.register( +vector_db = client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, provider_id="faiss", ) +vector_db_id = vector_db.identifier source = "https://www.paulgraham.com/greatwork.html" print("rag_tool> Ingesting document:", source) document = RAGDocument( @@ -35,7 +36,7 @@ document = RAGDocument( client.tool_runtime.rag_tool.insert( documents=[document], vector_db_id=vector_db_id, - chunk_size_in_tokens=50, + chunk_size_in_tokens=100, ) agent = Agent( client, diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index f9a6e5c55..f9a7e7b89 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): from .memory import MemoryToolRuntimeImpl - impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a1543457b..cb526e8ee 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,10 +5,15 @@ # the root directory of this source tree. import asyncio +import base64 +import io +import mimetypes import secrets import string from typing import Any +import httpx +from fastapi import UploadFile from pydantic import TypeAdapter from llama_stack.apis.common.content_types import ( @@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContentItem, TextContentItem, ) +from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -30,13 +36,18 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorIO, + VectorStoreChunkingStrategyStatic, + VectorStoreChunkingStrategyStaticConfig, +) from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, - make_overlapped_chunks, + parse_data_url, ) from .config import RagToolRuntimeConfig @@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, + files_api: Files, ): self.config = config self.vector_io_api = vector_io_api self.inference_api = inference_api + self.files_api = files_api async def initialize(self): pass @@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - chunks = [] + if not documents: + return + for doc in documents: - content = await content_from_doc(doc) - # TODO: we should add enrichment here as URLs won't be added to the metadata by default - chunks.extend( - make_overlapped_chunks( - doc.document_id, - content, - chunk_size_in_tokens, - chunk_size_in_tokens // 4, - doc.metadata, + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() + mime_type = parts["mimetype"] + else: + async with httpx.AsyncClient() as client: + response = await client.get(doc.content.uri) + file_data = response.content + mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream") + else: + content_str = await content_from_doc(doc) + file_data = content_str.encode("utf-8") + mime_type = doc.mime_type or "text/plain" + + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + + file_obj = io.BytesIO(file_data) + file_obj.name = filename + + upload_file = UploadFile(file=file_obj, filename=filename) + + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, ) ) - if not chunks: - return - - await self.vector_io_api.insert_chunks( - chunks=chunks, - vector_db_id=vector_db_id, - ) + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, + chunking_strategy=chunking_strategy, + ) async def query( self, diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 661851443..5a58fa7af 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]: ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.inference], + api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), remote_provider_spec( diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index 2affe2a2d..b208500d8 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models): client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) clear_registry() + + try: + client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime") + except Exception: + pass + yield client_with_models - # you must clean after the last test if you were running tests against - # a stateful server instance clear_registry() @@ -66,12 +70,13 @@ def assert_valid_text_response(response): def test_vector_db_insert_inline_and_query( client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension ): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + vector_db = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + vector_db_id = vector_db.identifier client_with_empty_registry.tool_runtime.rag_tool.insert( documents=sample_documents, @@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query( # list to check memory bank is successfully registered available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query( client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) # Query for the name of method response1 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What's the name of the fine-tunning method used?", ) assert_valid_chunk_response(response1) @@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query( # Query for the name of model response2 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="Which Llama model is mentioned?", ) assert_valid_chunk_response(response2) @@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i ) available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", ) assert_valid_text_response(response_with_metadata) assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "include_metadata_in_content": True, @@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i with pytest.raises((ValueError, BadRequestError)): client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "chunk_template": "This should raise a ValueError because it is missing the proper template variables", diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 05ccecb99..d18d90716 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti class TestRagQuery: async def test_query_raises_on_empty_vector_db_ids(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) with pytest.raises(ValueError): await rag_tool.query(content=MagicMock(), vector_db_ids=[]) async def test_query_chunk_metadata_handling(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) content = "test query content" vector_db_ids = ["db1"]