Added a draft implementation of the preprocessor chain.

This commit is contained in:
ilya-kolchinsky 2025-03-05 17:17:17 +01:00
parent 16764a2f06
commit b981181b25
7 changed files with 180 additions and 46 deletions

View file

@ -35,7 +35,7 @@ class PreprocessingDataFormat(Enum):
@json_schema_type @json_schema_type
class PreprocessingInput(BaseModel): class PreprocessorInput(BaseModel):
preprocessor_input_id: str preprocessor_input_id: str
preprocessor_input_type: Optional[PreprocessingDataType] = None preprocessor_input_type: Optional[PreprocessingDataType] = None
preprocessor_input_format: Optional[PreprocessingDataFormat] = None preprocessor_input_format: Optional[PreprocessingDataFormat] = None
@ -46,7 +46,16 @@ PreprocessorOptions = Dict[str, Any]
@json_schema_type @json_schema_type
class PreprocessingResponse(BaseModel): class PreprocessorChainElement(BaseModel):
preprocessor_id: str
options: Optional[PreprocessorOptions] = None
PreprocessorChain = List[PreprocessorChainElement]
@json_schema_type
class PreprocessorResponse(BaseModel):
status: bool status: bool
results: Optional[List[str | InterleavedContent | Chunk]] = None results: Optional[List[str | InterleavedContent | Chunk]] = None
@ -59,10 +68,21 @@ class PreprocessorStore(Protocol):
class Preprocessing(Protocol): class Preprocessing(Protocol):
preprocessor_store: PreprocessorStore preprocessor_store: PreprocessorStore
input_types: List[PreprocessingDataType]
output_types: List[PreprocessingDataType]
@webmethod(route="/preprocess", method="POST") @webmethod(route="/preprocess", method="POST")
async def preprocess( async def preprocess(
self, self,
preprocessor_id: str, preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput], preprocessor_inputs: List[PreprocessorInput],
options: Optional[PreprocessorOptions] = None, options: Optional[PreprocessorOptions] = None,
) -> PreprocessingResponse: ... ) -> PreprocessorResponse: ...
@webmethod(route="/chain_preprocess", method="POST")
async def chain_preprocess(
self,
preprocessors: PreprocessorChain,
preprocessor_inputs: List[PreprocessorInput],
is_rag_chain: Optional[bool] = False,
) -> PreprocessorResponse: ...

View file

@ -34,7 +34,13 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingInput, PreprocessingResponse, PreprocessorOptions from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessorChain,
PreprocessorInput,
PreprocessorOptions,
PreprocessorResponse,
)
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import ( from llama_stack.apis.scoring import (
ScoreBatchResponse, ScoreBatchResponse,
@ -52,6 +58,7 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.utils.chain import execute_preprocessor_chain
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
@ -501,11 +508,20 @@ class PreprocessingRouter(Preprocessing):
async def preprocess( async def preprocess(
self, self,
preprocessor_id: str, preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput], preprocessor_inputs: List[PreprocessorInput],
options: Optional[PreprocessorOptions] = None, options: Optional[PreprocessorOptions] = None,
) -> PreprocessingResponse: ) -> PreprocessorResponse:
return await self.routing_table.get_provider_impl(preprocessor_id).preprocess( return await self.routing_table.get_provider_impl(preprocessor_id).preprocess(
preprocessor_id=preprocessor_id, preprocessor_id=preprocessor_id,
preprocessor_inputs=preprocessor_inputs, preprocessor_inputs=preprocessor_inputs,
options=options, options=options,
) )
async def chain_preprocess(
self,
preprocessors: PreprocessorChain,
preprocessor_inputs: List[PreprocessorInput],
is_rag_chain: Optional[bool] = False,
) -> PreprocessorResponse:
preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors]
return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs, is_rag_chain)

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from itertools import pairwise
from typing import List
from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataType,
PreprocessorChain,
PreprocessorInput,
PreprocessorResponse,
)
log = logging.getLogger(__name__)
def validate_chain(chain_impls: List[Preprocessing], is_rag_chain: bool) -> bool:
if len(chain_impls) == 0:
log.error("Empty preprocessing chain was provided")
return False
if is_rag_chain and PreprocessingDataType.chunks not in chain_impls[-1].output_types:
log.error(
f"RAG preprocessing chain must end with a chunk-producing preprocessor, but the last preprocessor in the provided chain only supports {chain_impls[-1].output_types}"
)
return False
for current_preprocessor, next_preprocessor in pairwise(chain_impls):
current_output_types = current_preprocessor.output_types
next_input_types = next_preprocessor.input_types
if len(list(set(current_output_types) & set(next_input_types))) == 0:
log.error(
f"Incompatible input ({current_output_types}) and output({next_input_types}) preprocessor data types"
)
return False
return True
async def execute_preprocessor_chain(
    preprocessor_chain: PreprocessorChain,
    preprocessor_chain_impls: List[Preprocessing],
    preprocessor_inputs: List[PreprocessorInput],
    is_rag_chain: bool,
) -> PreprocessorResponse:
    """Run the given inputs through a chain of preprocessors, one after another.

    Args:
        preprocessor_chain: Chain elements supplying the ``preprocessor_id``
            and ``options`` passed to each stage.
        preprocessor_chain_impls: Implementations to invoke, positionally
            matched to ``preprocessor_chain``.
        preprocessor_inputs: Inputs handed to the first stage.
        is_rag_chain: Forwarded to ``validate_chain``.

    Returns:
        A response with ``status=True`` and the last stage's results on
        success. If the two lists differ in length, validation fails, or any
        stage reports a false status, the response has ``status=False`` and
        empty results.
    """
    # Reject mismatched lists up front; otherwise the mismatch would surface
    # as an IndexError mid-run or as validated stages that never execute.
    if len(preprocessor_chain) != len(preprocessor_chain_impls):
        log.error(
            f"Preprocessor chain has {len(preprocessor_chain)} elements but {len(preprocessor_chain_impls)} implementations were provided"
        )
        return PreprocessorResponse(status=False, results=[])

    if not validate_chain(preprocessor_chain_impls, is_rag_chain):
        return PreprocessorResponse(status=False, results=[])

    current_data = preprocessor_inputs

    # TODO: replace with a parallel implementation
    for current_params, current_impl in zip(preprocessor_chain, preprocessor_chain_impls):
        response = await current_impl.preprocess(
            preprocessor_id=current_params.preprocessor_id,
            preprocessor_inputs=current_data,
            options=current_params.options,
        )
        if not response.status:
            log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
            return PreprocessorResponse(status=False, results=[])
        # The results of this stage become the inputs of the next one.
        current_data = response.results

    return PreprocessorResponse(status=True, results=current_data)

View file

@ -14,10 +14,11 @@ from llama_stack.apis.preprocessing import (
Preprocessing, Preprocessing,
PreprocessingDataFormat, PreprocessingDataFormat,
PreprocessingDataType, PreprocessingDataType,
PreprocessingInput,
PreprocessingResponse,
Preprocessor, Preprocessor,
PreprocessorChain,
PreprocessorInput,
PreprocessorOptions, PreprocessorOptions,
PreprocessorResponse,
) )
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig
@ -29,14 +30,14 @@ log = logging.getLogger(__name__)
class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor can either receive documents (text or binary) or document URIs # this preprocessor can either receive documents (text or binary) or document URIs
INPUT_TYPES = [ input_types = [
PreprocessingDataType.binary_document, PreprocessingDataType.binary_document,
PreprocessingDataType.raw_text_document, PreprocessingDataType.raw_text_document,
PreprocessingDataType.document_uri, PreprocessingDataType.document_uri,
] ]
# this preprocessor optionally retrieves the documents and converts them into plain text # this preprocessor optionally retrieves the documents and converts them into plain text
OUTPUT_TYPES = [PreprocessingDataType.raw_text_document] output_types = [PreprocessingDataType.raw_text_document]
URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)") URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)")
@ -54,9 +55,9 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
async def preprocess( async def preprocess(
self, self,
preprocessor_id: str, preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput], preprocessor_inputs: List[PreprocessorInput],
options: Optional[PreprocessorOptions] = None, options: Optional[PreprocessorOptions] = None,
) -> PreprocessingResponse: ) -> PreprocessorResponse:
results = [] results = []
for inp in preprocessor_inputs: for inp in preprocessor_inputs:
@ -87,10 +88,18 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
results.append(document) results.append(document)
return PreprocessingResponse(status=True, results=results) return PreprocessorResponse(status=True, results=results)
async def chain_preprocess(
self,
preprocessors: PreprocessorChain,
preprocessor_inputs: List[PreprocessorInput],
is_rag_chain: Optional[bool] = False,
) -> PreprocessorResponse:
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
@staticmethod @staticmethod
async def _resolve_input_type(preprocessor_input: PreprocessingInput) -> PreprocessingDataType: async def _resolve_input_type(preprocessor_input: PreprocessorInput) -> PreprocessingDataType:
if preprocessor_input.preprocessor_input_type is not None: if preprocessor_input.preprocessor_input_type is not None:
return preprocessor_input.preprocessor_input_type return preprocessor_input.preprocessor_input_type
@ -104,7 +113,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
return PreprocessingDataType.raw_text_document return PreprocessingDataType.raw_text_document
@staticmethod @staticmethod
async def _fetch_document(preprocessor_input: PreprocessingInput) -> str | None: async def _fetch_document(preprocessor_input: PreprocessorInput) -> str | None:
if isinstance(preprocessor_input.path_or_content, str): if isinstance(preprocessor_input.path_or_content, str):
url = preprocessor_input.path_or_content url = preprocessor_input.path_or_content
if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url): if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
@ -125,7 +134,3 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
r = await client.get(url) r = await client.get(url)
return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text
@staticmethod
def is_pdf(preprocessor_input: PreprocessingInput):
return

View file

@ -13,10 +13,11 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.preprocessing import ( from llama_stack.apis.preprocessing import (
Preprocessing, Preprocessing,
PreprocessingDataType, PreprocessingDataType,
PreprocessingInput,
PreprocessingResponse,
Preprocessor, Preprocessor,
PreprocessorChain,
PreprocessorInput,
PreprocessorOptions, PreprocessorOptions,
PreprocessorResponse,
) )
from llama_stack.apis.vector_io import Chunk from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
@ -27,10 +28,10 @@ log = logging.getLogger(__name__)
class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor receives URLs / paths to documents as input # this preprocessor receives URLs / paths to documents as input
INPUT_TYPES = [PreprocessingDataType.document_uri] input_types = [PreprocessingDataType.document_uri]
# this preprocessor either only converts the documents into a text format, or also chunks them # this preprocessor either only converts the documents into a text format, or also chunks them
OUTPUT_TYPES = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks] output_types = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks]
def __init__(self, config: InlineDoclingConfig) -> None: def __init__(self, config: InlineDoclingConfig) -> None:
self.config = config self.config = config
@ -50,9 +51,9 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
async def preprocess( async def preprocess(
self, self,
preprocessor_id: str, preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput], preprocessor_inputs: List[PreprocessorInput],
options: Optional[PreprocessorOptions] = None, options: Optional[PreprocessorOptions] = None,
) -> PreprocessingResponse: ) -> PreprocessorResponse:
results = [] results = []
for inp in preprocessor_inputs: for inp in preprocessor_inputs:
@ -74,4 +75,12 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
result = converted_document.export_to_markdown() result = converted_document.export_to_markdown()
results.append(result) results.append(result)
return PreprocessingResponse(status=True, results=results) return PreprocessorResponse(status=True, results=results)
async def chain_preprocess(
self,
preprocessors: PreprocessorChain,
preprocessor_inputs: List[PreprocessorInput],
is_rag_chain: Optional[bool] = False,
) -> PreprocessorResponse:
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)

View file

@ -12,10 +12,11 @@ from llama_models.llama3.api import Tokenizer
from llama_stack.apis.preprocessing import ( from llama_stack.apis.preprocessing import (
Preprocessing, Preprocessing,
PreprocessingDataType, PreprocessingDataType,
PreprocessingInput,
PreprocessingResponse,
Preprocessor, Preprocessor,
PreprocessorChain,
PreprocessorInput,
PreprocessorOptions, PreprocessorOptions,
PreprocessorResponse,
) )
from llama_stack.apis.vector_io import Chunk from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
@ -31,8 +32,8 @@ class SimpleChunkingOptions(Enum):
class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor receives plain text and returns chunks # this preprocessor receives plain text and returns chunks
INPUT_TYPES = [PreprocessingDataType.raw_text_document] input_types = [PreprocessingDataType.raw_text_document]
OUTPUT_TYPES = [PreprocessingDataType.chunks] output_types = [PreprocessingDataType.chunks]
def __init__(self, config: InclineSimpleChunkingConfig) -> None: def __init__(self, config: InclineSimpleChunkingConfig) -> None:
self.config = config self.config = config
@ -48,9 +49,9 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
async def preprocess( async def preprocess(
self, self,
preprocessor_id: str, preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput], preprocessor_inputs: List[PreprocessorInput],
options: Optional[PreprocessorOptions] = None, options: Optional[PreprocessorOptions] = None,
) -> PreprocessingResponse: ) -> PreprocessorResponse:
chunks = [] chunks = []
window_len, overlap_len = self._resolve_chunk_size_params(options) window_len, overlap_len = self._resolve_chunk_size_params(options)
@ -61,7 +62,15 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
) )
chunks.extend(new_chunks) chunks.extend(new_chunks)
return PreprocessingResponse(status=True, results=chunks) return PreprocessorResponse(status=True, results=chunks)
async def chain_preprocess(
self,
preprocessors: PreprocessorChain,
preprocessor_inputs: List[PreprocessorInput],
is_rag_chain: Optional[bool] = False,
) -> PreprocessorResponse:
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]: def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
window_len = (options or {}).get( window_len = (options or {}).get(

View file

@ -18,7 +18,12 @@ from llama_stack.apis.common.content_types import (
TextContentItem, TextContentItem,
) )
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataFormat,
PreprocessorChainElement,
PreprocessorInput,
)
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
@ -67,17 +72,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents] preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
preprocessor_chain = [
conversion_response = await self.preprocessing_api.preprocess( PreprocessorChainElement(preprocessor_id="builtin::basic"),
preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs PreprocessorChainElement(preprocessor_id="builtin::chunking"),
]
preprocessor_response = await self.preprocessing_api.chain_preprocess(
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
) )
converted_inputs = conversion_response.results
chunking_response = await self.preprocessing_api.preprocess( chunks = preprocessor_response.results
preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
)
chunks = chunking_response.results
if not chunks: if not chunks:
return return
@ -197,13 +201,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
) )
@staticmethod @staticmethod
def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput: def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput:
if document.mime_type == "application/pdf": if document.mime_type == "application/pdf":
preprocessor_input_format = PreprocessingDataFormat.pdf preprocessor_input_format = PreprocessingDataFormat.pdf
else: else:
preprocessor_input_format = None preprocessor_input_format = None
return PreprocessingInput( return PreprocessorInput(
preprocessor_input_id=document.document_id, preprocessor_input_id=document.document_id,
preprocessor_input_format=preprocessor_input_format, preprocessor_input_format=preprocessor_input_format,
path_or_content=document.content, path_or_content=document.content,