# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from itertools import pairwise
from typing import List

from llama_stack.apis.preprocessing import (
    Preprocessing,
    PreprocessorChain,
    PreprocessorInput,
    PreprocessorResponse,
)

log = logging.getLogger(__name__)


def validate_chain(chain_impls: List[Preprocessing]) -> bool:
    """Check that a sequence of preprocessors can be chained together.

    A chain is valid when it is non-empty and, for every pair of adjacent
    preprocessors, the earlier one's ``output_types`` shares at least one
    element with the later one's ``input_types``.

    Args:
        chain_impls: Preprocessor implementations in execution order. Their
            ``output_types`` / ``input_types`` must be iterables of hashable
            values (they are compared via a set).

    Returns:
        ``True`` if the chain is valid, ``False`` otherwise; the reason for a
        rejection is logged at error level.
    """
    if not chain_impls:
        log.error("Empty preprocessing chain was provided")
        return False

    for current_preprocessor, next_preprocessor in pairwise(chain_impls):
        current_output_types = current_preprocessor.output_types
        next_input_types = next_preprocessor.input_types

        # At least one data type produced by this step must be accepted by the next one.
        if set(current_output_types).isdisjoint(next_input_types):
            log.error(
                "Incompatible output (%s) and input (%s) preprocessor data types",
                current_output_types,
                next_input_types,
            )
            return False

    return True


async def execute_preprocessor_chain(
    preprocessor_chain: PreprocessorChain,
    preprocessor_chain_impls: List[Preprocessing],
    preprocessor_inputs: List[PreprocessorInput],
) -> PreprocessorResponse:
    """Run a chain of preprocessors sequentially, feeding each step's results into the next.

    Args:
        preprocessor_chain: Per-step parameters; each element supplies the
            ``preprocessor_id`` and ``options`` passed to the implementation
            at the same position.
        preprocessor_chain_impls: Implementations to run, aligned
            position-for-position with ``preprocessor_chain``.
        preprocessor_inputs: Inputs handed to the first preprocessor.

    Returns:
        On success, a response whose ``preprocessor_output_type`` and
        ``results`` come from the last step. A response with
        ``success=False`` and empty ``results`` is returned when the
        implementations fail ``validate_chain``, when the number of steps and
        implementations differ, or when any step reports failure (in that
        case the failing step's ``preprocessor_output_type`` is propagated).
    """
    if not validate_chain(preprocessor_chain_impls):
        return PreprocessorResponse(success=False, results=[])

    # Steps and implementations are paired by position. Without this check a
    # longer chain would raise IndexError after earlier steps already ran,
    # and a shorter one would silently skip the trailing implementations.
    steps = list(preprocessor_chain)
    if len(steps) != len(preprocessor_chain_impls):
        log.error(
            "Preprocessor chain has %d steps but %d implementations were provided",
            len(steps),
            len(preprocessor_chain_impls),
        )
        return PreprocessorResponse(success=False, results=[])

    # Data flowing through the chain: starts as the caller's inputs and is
    # replaced by each step's results.
    current_data = preprocessor_inputs
    current_result_type = None

    # TODO: replace with a parallel implementation
    for step_params, step_impl in zip(steps, preprocessor_chain_impls):
        response = await step_impl.preprocess(
            preprocessor_id=step_params.preprocessor_id,
            preprocessor_inputs=current_data,
            options=step_params.options,
        )
        if not response.success:
            log.error("Preprocessor %s returned an error", step_params.preprocessor_id)
            return PreprocessorResponse(
                success=False, preprocessor_output_type=response.preprocessor_output_type, results=[]
            )
        current_data = response.results
        current_result_type = response.preprocessor_output_type

    # The chain is non-empty here (validated above), so current_data holds the last step's results.
    return PreprocessorResponse(success=True, preprocessor_output_type=current_result_type, results=current_data)