diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py
index 9e472cc1d..3d6376e91 100644
--- a/llama_stack/apis/preprocessing/preprocessing.py
+++ b/llama_stack/apis/preprocessing/preprocessing.py
@@ -35,7 +35,7 @@ class PreprocessingDataFormat(Enum):
 
 
 @json_schema_type
-class PreprocessingInput(BaseModel):
+class PreprocessorInput(BaseModel):
     preprocessor_input_id: str
     preprocessor_input_type: Optional[PreprocessingDataType] = None
     preprocessor_input_format: Optional[PreprocessingDataFormat] = None
@@ -46,7 +46,16 @@ PreprocessorOptions = Dict[str, Any]
 
 
 @json_schema_type
-class PreprocessingResponse(BaseModel):
+class PreprocessorChainElement(BaseModel):
+    preprocessor_id: str
+    options: Optional[PreprocessorOptions] = None
+
+
+PreprocessorChain = List[PreprocessorChainElement]
+
+
+@json_schema_type
+class PreprocessorResponse(BaseModel):
     status: bool
     results: Optional[List[str | InterleavedContent | Chunk]] = None
 
@@ -59,10 +68,21 @@ class PreprocessorStore(Protocol):
 class Preprocessing(Protocol):
     preprocessor_store: PreprocessorStore
 
+    input_types: List[PreprocessingDataType]
+    output_types: List[PreprocessingDataType]
+
     @webmethod(route="/preprocess", method="POST")
     async def preprocess(
         self,
         preprocessor_id: str,
-        preprocessor_inputs: List[PreprocessingInput],
+        preprocessor_inputs: List[PreprocessorInput],
         options: Optional[PreprocessorOptions] = None,
-    ) -> PreprocessingResponse: ...
+    ) -> PreprocessorResponse: ...
+
+    @webmethod(route="/chain_preprocess", method="POST")
+    async def chain_preprocess(
+        self,
+        preprocessors: PreprocessorChain,
+        preprocessor_inputs: List[PreprocessorInput],
+        is_rag_chain: Optional[bool] = False,
+    ) -> PreprocessorResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 19a2e6249..2c6de8378 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -34,7 +34,13 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import ModelType
-from llama_stack.apis.preprocessing import Preprocessing, PreprocessingInput, PreprocessingResponse, PreprocessorOptions
+from llama_stack.apis.preprocessing import (
+    Preprocessing,
+    PreprocessorChain,
+    PreprocessorInput,
+    PreprocessorOptions,
+    PreprocessorResponse,
+)
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -52,6 +58,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.distribution.utils.chain import execute_preprocessor_chain
 from llama_stack.providers.datatypes import RoutingTable
 from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
@@ -501,11 +508,20 @@ class PreprocessingRouter(Preprocessing):
     async def preprocess(
         self,
         preprocessor_id: str,
-        preprocessor_inputs: List[PreprocessingInput],
+        preprocessor_inputs: List[PreprocessorInput],
         options: Optional[PreprocessorOptions] = None,
-    ) -> PreprocessingResponse:
+    ) -> PreprocessorResponse:
         return await self.routing_table.get_provider_impl(preprocessor_id).preprocess(
             preprocessor_id=preprocessor_id,
             preprocessor_inputs=preprocessor_inputs,
             options=options,
         )
+
+    async def chain_preprocess(
+        self,
+        preprocessors: PreprocessorChain,
+        preprocessor_inputs: List[PreprocessorInput],
+        is_rag_chain: Optional[bool] = False,
+    ) -> PreprocessorResponse:
+        preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors]
+        return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs, is_rag_chain)
diff --git a/llama_stack/distribution/utils/chain.py b/llama_stack/distribution/utils/chain.py
new file mode 100644
index 000000000..118e13efc
--- /dev/null
+++ b/llama_stack/distribution/utils/chain.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import logging
+from itertools import pairwise
+from typing import List
+
+from llama_stack.apis.preprocessing import (
+    Preprocessing,
+    PreprocessingDataType,
+    PreprocessorChain,
+    PreprocessorInput,
+    PreprocessorResponse,
+)
+
+log = logging.getLogger(__name__)
+
+
+def validate_chain(chain_impls: List[Preprocessing], is_rag_chain: bool) -> bool:
+    if len(chain_impls) == 0:
+        log.error("Empty preprocessing chain was provided")
+        return False
+
+    if is_rag_chain and PreprocessingDataType.chunks not in chain_impls[-1].output_types:
+        log.error(
+            f"RAG preprocessing chain must end with a chunk-producing preprocessor, but the last preprocessor in the provided chain only supports {chain_impls[-1].output_types}"
+        )
+        return False
+
+    for current_preprocessor, next_preprocessor in pairwise(chain_impls):
+        current_output_types = current_preprocessor.output_types
+        next_input_types = next_preprocessor.input_types
+
+        if not set(current_output_types) & set(next_input_types):
+            log.error(
+                f"Incompatible output ({current_output_types}) and input ({next_input_types}) preprocessor data types"
+            )
+            return False
+
+    return True
+
+
+async def execute_preprocessor_chain(
+    preprocessor_chain: PreprocessorChain,
+    preprocessor_chain_impls: List[Preprocessing],
+    preprocessor_inputs: List[PreprocessorInput],
+    is_rag_chain: bool,
+) -> PreprocessorResponse:
+    if not validate_chain(preprocessor_chain_impls, is_rag_chain):
+        return PreprocessorResponse(status=False, results=[])
+
+    current_inputs = preprocessor_inputs
+    current_outputs = []
+
+    # TODO: replace with a parallel implementation
+    for i, current_params in enumerate(preprocessor_chain):
+        current_impl = preprocessor_chain_impls[i]
+        response = await current_impl.preprocess(
+            preprocessor_id=current_params.preprocessor_id,
+            preprocessor_inputs=current_inputs,
+            options=current_params.options,
+        )
+        if not response.status:
+            log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
+            return PreprocessorResponse(status=False, results=[])
+        current_outputs = response.results
+        current_inputs = current_outputs
+
+    return PreprocessorResponse(status=True, results=current_outputs)
diff --git a/llama_stack/providers/inline/preprocessing/basic/basic.py b/llama_stack/providers/inline/preprocessing/basic/basic.py
index df2d138cd..5a51ab0b6 100644
--- a/llama_stack/providers/inline/preprocessing/basic/basic.py
+++ b/llama_stack/providers/inline/preprocessing/basic/basic.py
@@ -14,10 +14,11 @@ from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataFormat,
     PreprocessingDataType,
-    PreprocessingInput,
-    PreprocessingResponse,
     Preprocessor,
+    PreprocessorChain,
+    PreprocessorInput,
     PreprocessorOptions,
+    PreprocessorResponse,
 )
 from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
 from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig
@@ -29,14 +30,14 @@ log = logging.getLogger(__name__)
 
 
 class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
     # this preprocessor can either receive documents (text or binary) or document URIs
-    INPUT_TYPES = [
+    input_types = [
         PreprocessingDataType.binary_document,
         PreprocessingDataType.raw_text_document,
         PreprocessingDataType.document_uri,
     ]
 
     # this preprocessor optionally retrieves the documents and converts them into plain text
-    OUTPUT_TYPES = [PreprocessingDataType.raw_text_document]
+    output_types = [PreprocessingDataType.raw_text_document]
 
     URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)")
@@ -54,9 +55,9 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
     async def preprocess(
         self,
         preprocessor_id: str,
-        preprocessor_inputs: List[PreprocessingInput],
+        preprocessor_inputs: List[PreprocessorInput],
         options: Optional[PreprocessorOptions] = None,
-    ) -> PreprocessingResponse:
+    ) -> PreprocessorResponse:
         results = []
 
         for inp in preprocessor_inputs:
@@ -87,10 +88,18 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
 
             results.append(document)
 
-        return PreprocessingResponse(status=True, results=results)
+        return PreprocessorResponse(status=True, results=results)
+
+    async def chain_preprocess(
+        self,
+        preprocessors: PreprocessorChain,
+        preprocessor_inputs: List[PreprocessorInput],
+        is_rag_chain: Optional[bool] = False,
+    ) -> PreprocessorResponse:
+        return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
 
     @staticmethod
-    async def _resolve_input_type(preprocessor_input: PreprocessingInput) -> PreprocessingDataType:
+    async def _resolve_input_type(preprocessor_input: PreprocessorInput) -> PreprocessingDataType:
         if preprocessor_input.preprocessor_input_type is not None:
             return preprocessor_input.preprocessor_input_type
 
@@ -104,7 +113,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
         return PreprocessingDataType.raw_text_document
 
     @staticmethod
-    async def _fetch_document(preprocessor_input: PreprocessingInput) -> str | None:
+    async def _fetch_document(preprocessor_input: PreprocessorInput) -> str | None:
         if isinstance(preprocessor_input.path_or_content, str):
             url = preprocessor_input.path_or_content
             if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
@@ -125,7 +134,3 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
             r = await client.get(url)
 
         return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text
-
-    @staticmethod
-    def is_pdf(preprocessor_input: PreprocessingInput):
-        return
diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py
index 90b5fd912..3fdc29b0a 100644
--- a/llama_stack/providers/inline/preprocessing/docling/docling.py
+++ b/llama_stack/providers/inline/preprocessing/docling/docling.py
@@ -13,10 +13,11 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataType,
-    PreprocessingInput,
-    PreprocessingResponse,
     Preprocessor,
+    PreprocessorChain,
+    PreprocessorInput,
     PreprocessorOptions,
+    PreprocessorResponse,
 )
 from llama_stack.apis.vector_io import Chunk
 from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
@@ -27,10 +28,10 @@ log = logging.getLogger(__name__)
 
 class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
     # this preprocessor receives URLs / paths to documents as input
-    INPUT_TYPES = [PreprocessingDataType.document_uri]
+    input_types = [PreprocessingDataType.document_uri]
 
     # this preprocessor either only converts the documents into a text format, or also chunks them
-    OUTPUT_TYPES = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks]
+    output_types = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks]
 
     def __init__(self, config: InlineDoclingConfig) -> None:
         self.config = config
@@ -50,9 +51,9 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
     async def preprocess(
         self,
         preprocessor_id: str,
-        preprocessor_inputs: List[PreprocessingInput],
+        preprocessor_inputs: List[PreprocessorInput],
         options: Optional[PreprocessorOptions] = None,
-    ) -> PreprocessingResponse:
+    ) -> PreprocessorResponse:
         results = []
 
         for inp in preprocessor_inputs:
@@ -74,4 +75,12 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
                 result = converted_document.export_to_markdown()
             results.append(result)
 
-        return PreprocessingResponse(status=True, results=results)
+        return PreprocessorResponse(status=True, results=results)
+
+    async def chain_preprocess(
+        self,
+        preprocessors: PreprocessorChain,
+        preprocessor_inputs: List[PreprocessorInput],
+        is_rag_chain: Optional[bool] = False,
+    ) -> PreprocessorResponse:
+        return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
diff --git a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
index 542178db3..2822c3f15 100644
--- a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
+++ b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
@@ -12,10 +12,11 @@ from llama_models.llama3.api import Tokenizer
 from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataType,
-    PreprocessingInput,
-    PreprocessingResponse,
     Preprocessor,
+    PreprocessorChain,
+    PreprocessorInput,
     PreprocessorOptions,
+    PreprocessorResponse,
 )
 from llama_stack.apis.vector_io import Chunk
 from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
@@ -31,8 +32,8 @@ class SimpleChunkingOptions(Enum):
 
 class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
     # this preprocessor receives plain text and returns chunks
-    INPUT_TYPES = [PreprocessingDataType.raw_text_document]
-    OUTPUT_TYPES = [PreprocessingDataType.chunks]
+    input_types = [PreprocessingDataType.raw_text_document]
+    output_types = [PreprocessingDataType.chunks]
 
     def __init__(self, config: InclineSimpleChunkingConfig) -> None:
         self.config = config
@@ -48,9 +49,9 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
     async def preprocess(
         self,
         preprocessor_id: str,
-        preprocessor_inputs: List[PreprocessingInput],
+        preprocessor_inputs: List[PreprocessorInput],
         options: Optional[PreprocessorOptions] = None,
-    ) -> PreprocessingResponse:
+    ) -> PreprocessorResponse:
         chunks = []
 
         window_len, overlap_len = self._resolve_chunk_size_params(options)
@@ -61,7 +62,15 @@
             )
             chunks.extend(new_chunks)
 
-        return PreprocessingResponse(status=True, results=chunks)
+        return PreprocessorResponse(status=True, results=chunks)
+
+    async def chain_preprocess(
+        self,
+        preprocessors: PreprocessorChain,
+        preprocessor_inputs: List[PreprocessorInput],
+        is_rag_chain: Optional[bool] = False,
+    ) -> PreprocessorResponse:
+        return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
 
     def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
         window_len = (options or {}).get(
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index a0e0573a7..0bf43eb89 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -18,7 +18,12 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
+from llama_stack.apis.preprocessing import (
+    Preprocessing,
+    PreprocessingDataFormat,
+    PreprocessorChainElement,
+    PreprocessorInput,
+)
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -67,17 +72,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
-
-        conversion_response = await self.preprocessing_api.preprocess(
-            preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
+        preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
+        preprocessor_chain = [
+            PreprocessorChainElement(preprocessor_id="builtin::basic"),
+            PreprocessorChainElement(preprocessor_id="builtin::chunking"),
+        ]
+        preprocessor_response = await self.preprocessing_api.chain_preprocess(
+            preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
         )
-        converted_inputs = conversion_response.results
-        chunking_response = await self.preprocessing_api.preprocess(
-            preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
-        )
-        chunks = chunking_response.results
+        chunks = preprocessor_response.results
 
         if not chunks:
             return
 
@@ -197,13 +201,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         )
 
     @staticmethod
-    def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
+    def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput:
         if document.mime_type == "application/pdf":
             preprocessor_input_format = PreprocessingDataFormat.pdf
         else:
             preprocessor_input_format = None
 
-        return PreprocessingInput(
+        return PreprocessorInput(
             preprocessor_input_id=document.document_id,
             preprocessor_input_format=preprocessor_input_format,
             path_or_content=document.content,
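
---

For reviewers, a minimal end-to-end sketch of the chained API this patch introduces (not part of the diff). The preprocessor IDs and the `PreprocessorChainElement`/`PreprocessorInput` fields come from the changes above; `preprocessing_api` stands in for a resolved `PreprocessingRouter` instance, and the surrounding setup is hypothetical:

```python
from llama_stack.apis.preprocessing import PreprocessorChainElement, PreprocessorInput


async def ingest(preprocessing_api, url: str):
    # Mirror the RAG ingestion path in memory.py: fetch/convert, then chunk.
    chain = [
        PreprocessorChainElement(preprocessor_id="builtin::basic"),
        PreprocessorChainElement(preprocessor_id="builtin::chunking"),
    ]
    inputs = [PreprocessorInput(preprocessor_input_id="doc-1", path_or_content=url)]
    # is_rag_chain=True makes validate_chain() reject any chain whose last
    # preprocessor cannot emit PreprocessingDataType.chunks.
    response = await preprocessing_api.chain_preprocess(
        preprocessors=chain,
        preprocessor_inputs=inputs,
        is_rag_chain=True,
    )
    if not response.status:
        raise RuntimeError("preprocessing chain failed; see server logs")
    return response.results  # chunks, ready for vector_io insertion
```

Each router hop resolves the provider implementation per chain element, and `execute_preprocessor_chain` threads one stage's `results` into the next stage's `preprocessor_inputs`, failing fast on the first stage whose `status` is false.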