From 6cbc298edbed9ae719c2efdaa9cf3ef4364fe8e7 Mon Sep 17 00:00:00 2001 From: ilya-kolchinsky Date: Thu, 6 Mar 2025 14:22:19 +0100 Subject: [PATCH] Added the preprocessing chain parameter to the RAG tool insert API. --- llama_stack/apis/tools/rag_tool.py | 2 ++ llama_stack/distribution/routers/routers.py | 3 ++- .../providers/inline/tool_runtime/rag/memory.py | 10 +++++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 2b9ef10d8..7065142d8 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated, Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent +from llama_stack.apis.preprocessing import PreprocessorChain from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -79,6 +80,7 @@ class RAGToolRuntime(Protocol): documents: List[RAGDocument], vector_db_id: str, chunk_size_in_tokens: int = 512, + preprocessor_chain: Optional[PreprocessorChain] = None, ) -> None: """Index documents so they can be used by the RAG system""" ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 83752abd3..460de5c47 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -458,9 +458,10 @@ class ToolRuntimeRouter(ToolRuntime): documents: List[RAGDocument], vector_db_id: str, chunk_size_in_tokens: int = 512, + preprocessor_chain: Optional[PreprocessorChain] = None, ) -> None: return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens + documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain ) def __init__( diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a639a44b..6a81b9f16 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -22,6 +22,7 @@ from llama_stack.apis.preprocessing import ( Preprocessing, PreprocessingDataFormat, PreprocessingDataType, + PreprocessorChain, PreprocessorChainElement, PreprocessorInput, ) @@ -49,6 +50,11 @@ def make_random_string(length: int = 8): class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): + DEFAULT_PREPROCESSING_CHAIN = [ + PreprocessorChainElement(preprocessor_id="builtin::basic"), + PreprocessorChainElement(preprocessor_id="builtin::chunking"), + ] + def __init__( self, config: RagToolRuntimeConfig, @@ -72,6 +78,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): documents: List[RAGDocument], vector_db_id: str, chunk_size_in_tokens: int = 512, + preprocessor_chain: Optional[PreprocessorChain] = None, ) -> None: preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents] preprocessor_chain = [ @@ -79,7 +86,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): PreprocessorChainElement(preprocessor_id="builtin::chunking"), ] preprocessor_response = await self.preprocessing_api.chain_preprocess( - preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs + preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN, + preprocessor_inputs=preprocessor_inputs, ) if not preprocessor_response.success: