diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py
index 3d6376e91..f5c34becd 100644
--- a/llama_stack/apis/preprocessing/preprocessing.py
+++ b/llama_stack/apis/preprocessing/preprocessing.py
@@ -56,7 +56,8 @@ PreprocessorChain = List[PreprocessorChainElement]
 
 
 @json_schema_type
 class PreprocessorResponse(BaseModel):
-    status: bool
+    success: bool
+    preprocessor_output_type: PreprocessingDataType
     results: Optional[List[str | InterleavedContent | Chunk]] = None
 
@@ -84,5 +85,4 @@ class Preprocessing(Protocol):
         self,
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessorInput],
-        is_rag_chain: Optional[bool] = False,
     ) -> PreprocessorResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2c6de8378..83752abd3 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -521,7 +521,6 @@ class PreprocessingRouter(Preprocessing):
         self,
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessorInput],
-        is_rag_chain: Optional[bool] = False,
     ) -> PreprocessorResponse:
         preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors]
-        return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs, is_rag_chain)
+        return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs)
diff --git a/llama_stack/distribution/utils/chain.py b/llama_stack/distribution/utils/chain.py
index 118e13efc..e31e12410 100644
--- a/llama_stack/distribution/utils/chain.py
+++ b/llama_stack/distribution/utils/chain.py
@@ -9,7 +9,6 @@ from typing import List
 
 from llama_stack.apis.preprocessing import (
     Preprocessing,
-    PreprocessingDataType,
     PreprocessorChain,
     PreprocessorInput,
     PreprocessorResponse,
@@ -18,17 +17,11 @@ from llama_stack.apis.preprocessing import (
 log = logging.getLogger(__name__)
 
 
-def validate_chain(chain_impls: List[Preprocessing], is_rag_chain: bool) -> bool:
+def validate_chain(chain_impls: List[Preprocessing]) -> bool:
     if len(chain_impls) == 0:
         log.error("Empty preprocessing chain was provided")
         return False
 
-    if is_rag_chain and PreprocessingDataType.chunks not in chain_impls[-1].output_types:
-        log.error(
-            f"RAG preprocessing chain must end with a chunk-producing preprocessor, but the last preprocessor in the provided chain only supports {chain_impls[-1].output_types}"
-        )
-        return False
-
     for current_preprocessor, next_preprocessor in pairwise(chain_impls):
         current_output_types = current_preprocessor.output_types
         next_input_types = next_preprocessor.input_types
@@ -46,13 +39,13 @@ async def execute_preprocessor_chain(
     preprocessor_chain: PreprocessorChain,
     preprocessor_chain_impls: List[Preprocessing],
     preprocessor_inputs: List[PreprocessorInput],
-    is_rag_chain: bool,
 ) -> PreprocessorResponse:
-    if not validate_chain(preprocessor_chain_impls, is_rag_chain):
-        return PreprocessorResponse(status=False, results=[])
+    if not validate_chain(preprocessor_chain_impls):
+        return PreprocessorResponse(success=False, results=[])
 
     current_inputs = preprocessor_inputs
     current_outputs = []
+    current_result_type = None
 
     # TODO: replace with a parallel implementation
     for i, current_params in enumerate(preprocessor_chain):
@@ -62,10 +55,13 @@ async def execute_preprocessor_chain(
             preprocessor_inputs=current_inputs,
             options=current_params.options,
         )
-        if not response.status:
+        if not response.success:
             log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
-            return PreprocessorResponse(status=False, results=[])
+            return PreprocessorResponse(
+                success=False, preprocessor_output_type=response.preprocessor_output_type, results=[]
+            )
         current_outputs = response.results
         current_inputs = current_outputs
+        current_result_type = response.preprocessor_output_type
 
-    return PreprocessorResponse(status=True, results=current_outputs)
+    return PreprocessorResponse(success=True, preprocessor_output_type=current_result_type, results=current_outputs)
diff --git a/llama_stack/providers/inline/preprocessing/basic/basic.py b/llama_stack/providers/inline/preprocessing/basic/basic.py
index 5a51ab0b6..ae7d03b07 100644
--- a/llama_stack/providers/inline/preprocessing/basic/basic.py
+++ b/llama_stack/providers/inline/preprocessing/basic/basic.py
@@ -88,13 +88,14 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
             results.append(document)
 
-        return PreprocessorResponse(status=True, results=results)
+        return PreprocessorResponse(
+            success=True, preprocessor_output_type=PreprocessingDataType.raw_text_document, results=results
+        )
 
     async def chain_preprocess(
         self,
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessorInput],
-        is_rag_chain: Optional[bool] = False,
     ) -> PreprocessorResponse:
         return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py
index 3fdc29b0a..3492a70c1 100644
--- a/llama_stack/providers/inline/preprocessing/docling/docling.py
+++ b/llama_stack/providers/inline/preprocessing/docling/docling.py
@@ -75,12 +75,14 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
                 result = converted_document.export_to_markdown()
             results.append(result)
 
-        return PreprocessorResponse(status=True, results=results)
+        preprocessor_output_type = (
+            PreprocessingDataType.chunks if self.config.chunk else PreprocessingDataType.raw_text_document
+        )
+        return PreprocessorResponse(success=True, preprocessor_output_type=preprocessor_output_type, results=results)
 
     async def chain_preprocess(
         self,
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessorInput],
-        is_rag_chain: Optional[bool] = False,
     ) -> PreprocessorResponse:
         return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
diff --git a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
index 2822c3f15..5aea70079 100644
--- a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
+++ b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py
@@ -62,13 +62,12 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
             )
             chunks.extend(new_chunks)
 
-        return PreprocessorResponse(status=True, results=chunks)
+        return PreprocessorResponse(success=True, preprocessor_output_type=PreprocessingDataType.chunks, results=chunks)
 
     async def chain_preprocess(
         self,
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessorInput],
-        is_rag_chain: Optional[bool] = False,
     ) -> PreprocessorResponse:
         return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 0bf43eb89..6a639a44b 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataFormat,
+    PreprocessingDataType,
     PreprocessorChainElement,
     PreprocessorInput,
 )
@@ -81,9 +82,19 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
         )
 
-        chunks = preprocessor_response.results
+        if not preprocessor_response.success:
+            log.error("Preprocessor chain returned an error")
+            return
+        if preprocessor_response.preprocessor_output_type != PreprocessingDataType.chunks:
+            log.error(
+                f"Preprocessor chain returned {preprocessor_response.preprocessor_output_type} instead of {PreprocessingDataType.chunks}"
+            )
+            return
+
+        chunks = preprocessor_response.results
 
         if not chunks:
+            log.error("No chunks returned by the preprocessor chain")
             return
 
         await self.vector_io_api.insert_chunks(
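For reviewers, a minimal sketch of what a caller looks like under the updated contract. The `ingest` helper and the preprocessor IDs below are hypothetical placeholders; only `chain_preprocess`, `success`, `preprocessor_output_type`, and `results` come from this diff.

    from typing import List

    from llama_stack.apis.preprocessing import (
        PreprocessingDataType,
        PreprocessorChainElement,
        PreprocessorInput,
    )


    async def ingest(preprocessing_api, docs: List[PreprocessorInput]):
        # Hypothetical two-step chain; the IDs are placeholders, not
        # identifiers defined by this change.
        chain = [
            PreprocessorChainElement(preprocessor_id="example::converter"),
            PreprocessorChainElement(preprocessor_id="example::chunker"),
        ]
        response = await preprocessing_api.chain_preprocess(
            preprocessors=chain, preprocessor_inputs=docs
        )
        # `success` replaces the old boolean `status` field.
        if not response.success:
            raise RuntimeError("preprocessor chain failed")
        # With `is_rag_chain` removed from validate_chain, callers now check
        # the terminal output type via the new `preprocessor_output_type` field.
        if response.preprocessor_output_type != PreprocessingDataType.chunks:
            raise RuntimeError("chain did not end in a chunk-producing preprocessor")
        return response.results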