diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py index 083368b5b..a41a649ff 100644 --- a/llama_stack/apis/preprocessing/preprocessing.py +++ b/llama_stack/apis/preprocessing/preprocessing.py @@ -70,19 +70,8 @@ class PreprocessorStore(Protocol): class Preprocessing(Protocol): preprocessor_store: PreprocessorStore - input_types: List[PreprocessingDataType] - output_types: List[PreprocessingDataType] - @webmethod(route="/preprocess", method="POST") async def preprocess( - self, - preprocessor_id: str, - preprocessor_inputs: List[PreprocessingDataElement], - options: Optional[PreprocessorOptions] = None, - ) -> PreprocessorResponse: ... - - @webmethod(route="/chain_preprocess", method="POST") - async def chain_preprocess( self, preprocessors: PreprocessorChain, preprocessor_inputs: List[PreprocessingDataElement], diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index dc932e867..4c7a4cc00 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -45,7 +45,6 @@ from llama_stack.apis.preprocessing import ( Preprocessing, PreprocessingDataElement, PreprocessorChain, - PreprocessorOptions, PreprocessorResponse, ) from llama_stack.apis.safety import RunShieldResponse, Safety @@ -714,22 +713,6 @@ class PreprocessingRouter(Preprocessing): pass async def preprocess( - self, - preprocessor_id: str, - preprocessor_inputs: List[PreprocessingDataElement], - options: Optional[PreprocessorOptions] = None, - ) -> PreprocessorResponse: - logcat.debug( - "core", - f"PreprocessingRouter.preprocess: {preprocessor_id}, {len(preprocessor_inputs)} inputs, options={options}", - ) - return await self.routing_table.get_provider_impl(preprocessor_id).preprocess( - preprocessor_id=preprocessor_id, - preprocessor_inputs=preprocessor_inputs, - options=options, - ) - - async def chain_preprocess( self, preprocessors: PreprocessorChain, preprocessor_inputs: List[PreprocessingDataElement], diff --git a/llama_stack/distribution/utils/chain.py b/llama_stack/distribution/utils/chain.py index 2cd6bd0b2..22580cdd2 100644 --- a/llama_stack/distribution/utils/chain.py +++ b/llama_stack/distribution/utils/chain.py @@ -8,16 +8,16 @@ from itertools import pairwise from typing import List from llama_stack.apis.preprocessing import ( - Preprocessing, PreprocessingDataElement, PreprocessorChain, PreprocessorResponse, ) +from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate log = logging.getLogger(__name__) -def validate_chain(chain_impls: List[Preprocessing]) -> bool: +def validate_chain(chain_impls: List[PreprocessorsProtocolPrivate]) -> bool: if len(chain_impls) == 0: log.error("Empty preprocessing chain was provided") return False @@ -37,7 +37,7 @@ def validate_chain(chain_impls: List[Preprocessing]) -> bool: async def execute_preprocessor_chain( preprocessor_chain: PreprocessorChain, - preprocessor_chain_impls: List[Preprocessing], + preprocessor_chain_impls: List[PreprocessorsProtocolPrivate], preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: if not validate_chain(preprocessor_chain_impls): @@ -50,7 +50,7 @@ async def execute_preprocessor_chain( # TODO: replace with a parallel implementation for i, current_params in enumerate(preprocessor_chain): current_impl = preprocessor_chain_impls[i] - response = await current_impl.preprocess( + response = await current_impl.do_preprocess( preprocessor_id=current_params.preprocessor_id, preprocessor_inputs=current_inputs, options=current_params.options, diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index f34da79c0..549a953b0 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,7 +13,13 @@ from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasets import Dataset from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model -from llama_stack.apis.preprocessing import Preprocessor +from llama_stack.apis.preprocessing import ( + PreprocessingDataElement, + PreprocessingDataType, + Preprocessor, + PreprocessorOptions, + PreprocessorResponse, +) from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield from llama_stack.apis.tools import Tool @@ -60,10 +66,20 @@ class ToolsProtocolPrivate(Protocol): class PreprocessorsProtocolPrivate(Protocol): + input_types: List[PreprocessingDataType] + output_types: List[PreprocessingDataType] + async def register_preprocessor(self, preprocessor: Preprocessor) -> None: ... async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... + async def do_preprocess( + self, + preprocessor_id: str, + preprocessor_inputs: List[PreprocessingDataElement], + options: Optional[PreprocessorOptions] = None, + ) -> PreprocessorResponse: ... + @json_schema_type class ProviderSpec(BaseModel): diff --git a/llama_stack/providers/inline/preprocessing/basic/basic.py b/llama_stack/providers/inline/preprocessing/basic/basic.py index f31eff3e9..363afc7bc 100644 --- a/llama_stack/providers/inline/preprocessing/basic/basic.py +++ b/llama_stack/providers/inline/preprocessing/basic/basic.py @@ -52,7 +52,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... - async def preprocess( + async def do_preprocess( self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingDataElement], @@ -98,12 +98,12 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): success=True, output_data_type=PreprocessingDataType.raw_text_document, results=results ) - async def chain_preprocess( + async def preprocess( self, preprocessors: PreprocessorChain, preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: - return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) + return await self.do_preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) @staticmethod def _resolve_input_type(preprocessor_input: PreprocessingDataElement) -> PreprocessingDataType: diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py index 281c72b54..c292e8f89 100644 --- a/llama_stack/providers/inline/preprocessing/docling/docling.py +++ b/llama_stack/providers/inline/preprocessing/docling/docling.py @@ -47,7 +47,7 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... - async def preprocess( + async def do_preprocess( self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingDataElement], @@ -106,9 +106,9 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate ) return PreprocessorResponse(success=True, output_data_type=output_data_type, results=results) - async def chain_preprocess( + async def preprocess( self, preprocessors: PreprocessorChain, preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: - return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) + return await self.do_preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) diff --git a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py index 38ed18837..98d2b63d3 100644 --- a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py +++ b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py @@ -47,7 +47,7 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... - async def preprocess( + async def do_preprocess( self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingDataElement], @@ -72,12 +72,12 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): return PreprocessorResponse(success=True, output_data_type=PreprocessingDataType.chunks, results=chunks) - async def chain_preprocess( + async def preprocess( self, preprocessors: PreprocessorChain, preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: - return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) + return await self.do_preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]: window_len = (options or {}).get( diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index bda130811..c950cde81 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -81,7 +81,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): preprocessor_chain: Optional[PreprocessorChain] = None, ) -> None: preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents] - preprocessor_response = await self.preprocessing_api.chain_preprocess( + preprocessor_response = await self.preprocessing_api.preprocess( preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN, preprocessor_inputs=preprocessor_inputs, )