From f10a4128984b86d560371ec0b016e68fec78d237 Mon Sep 17 00:00:00 2001 From: ilya-kolchinsky Date: Thu, 6 Mar 2025 16:46:59 +0100 Subject: [PATCH] Fixed multiple bugs. --- .../apis/preprocessing/preprocessing.py | 19 +++--- llama_stack/distribution/routers/routers.py | 6 +- llama_stack/distribution/utils/chain.py | 12 ++-- .../inline/preprocessing/basic/basic.py | 58 ++++++++++--------- .../inline/preprocessing/docling/docling.py | 38 ++++++++---- .../simple_chunking/simple_chunking.py | 20 +++++-- .../inline/tool_runtime/rag/memory.py | 27 ++++----- 7 files changed, 102 insertions(+), 78 deletions(-) diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py index f5c34becd..083368b5b 100644 --- a/llama_stack/apis/preprocessing/preprocessing.py +++ b/llama_stack/apis/preprocessing/preprocessing.py @@ -32,14 +32,15 @@ class PreprocessingDataFormat(Enum): json = "json" html = "html" csv = "csv" + txt = "txt" @json_schema_type -class PreprocessorInput(BaseModel): - preprocessor_input_id: str - preprocessor_input_type: Optional[PreprocessingDataType] = None - preprocessor_input_format: Optional[PreprocessingDataFormat] = None - path_or_content: str | InterleavedContent | URL +class PreprocessingDataElement(BaseModel): + data_element_id: str + data_element_type: Optional[PreprocessingDataType] = None + data_element_format: Optional[PreprocessingDataFormat] = None + data_element_path_or_content: str | InterleavedContent | URL | Chunk | None PreprocessorOptions = Dict[str, Any] @@ -57,8 +58,8 @@ PreprocessorChain = List[PreprocessorChainElement] @json_schema_type class PreprocessorResponse(BaseModel): success: bool - preprocessor_output_type: PreprocessingDataType - results: Optional[List[str | InterleavedContent | Chunk]] = None + output_data_type: PreprocessingDataType + results: Optional[List[PreprocessingDataElement]] = None class PreprocessorStore(Protocol): @@ -76,7 +77,7 @@ class Preprocessing(Protocol): async def preprocess( self, preprocessor_id: str, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], options: Optional[PreprocessorOptions] = None, ) -> PreprocessorResponse: ... @@ -84,5 +85,5 @@ class Preprocessing(Protocol): async def chain_preprocess( self, preprocessors: PreprocessorChain, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 460de5c47..c24ce81d1 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -36,8 +36,8 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models import ModelType from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, PreprocessorChain, - PreprocessorInput, PreprocessorOptions, PreprocessorResponse, ) @@ -509,7 +509,7 @@ class PreprocessingRouter(Preprocessing): async def preprocess( self, preprocessor_id: str, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], options: Optional[PreprocessorOptions] = None, ) -> PreprocessorResponse: return await self.routing_table.get_provider_impl(preprocessor_id).preprocess( @@ -521,7 +521,7 @@ class PreprocessingRouter(Preprocessing): async def chain_preprocess( self, preprocessors: PreprocessorChain, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors] return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs) diff --git a/llama_stack/distribution/utils/chain.py b/llama_stack/distribution/utils/chain.py index e31e12410..2cd6bd0b2 100644 --- a/llama_stack/distribution/utils/chain.py +++ b/llama_stack/distribution/utils/chain.py @@ -9,8 +9,8 @@ from typing import List from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, PreprocessorChain, - PreprocessorInput, PreprocessorResponse, ) @@ -38,7 +38,7 @@ def validate_chain(chain_impls: List[Preprocessing]) -> bool: async def execute_preprocessor_chain( preprocessor_chain: PreprocessorChain, preprocessor_chain_impls: List[Preprocessing], - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: if not validate_chain(preprocessor_chain_impls): return PreprocessorResponse(success=False, results=[]) @@ -57,11 +57,9 @@ async def execute_preprocessor_chain( ) if not response.success: log.error(f"Preprocessor {current_params.preprocessor_id} returned an error") - return PreprocessorResponse( - success=False, preprocessor_output_type=response.preprocessor_output_type, results=[] - ) + return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[]) current_outputs = response.results current_inputs = current_outputs - current_result_type = response.preprocessor_output_type + current_result_type = response.output_data_type - return PreprocessorResponse(success=True, preprocessor_output_type=current_result_type, results=current_outputs) + return PreprocessorResponse(success=True, output_data_type=current_result_type, results=current_outputs) diff --git a/llama_stack/providers/inline/preprocessing/basic/basic.py b/llama_stack/providers/inline/preprocessing/basic/basic.py index ae7d03b07..f31eff3e9 100644 --- a/llama_stack/providers/inline/preprocessing/basic/basic.py +++ b/llama_stack/providers/inline/preprocessing/basic/basic.py @@ -12,11 +12,11 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, PreprocessingDataFormat, PreprocessingDataType, Preprocessor, PreprocessorChain, - PreprocessorInput, PreprocessorOptions, PreprocessorResponse, ) @@ -55,7 +55,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): async def preprocess( self, preprocessor_id: str, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], options: Optional[PreprocessorOptions] = None, ) -> PreprocessorResponse: results = [] @@ -68,63 +68,69 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): if document is None: continue elif input_type == PreprocessingDataType.binary_document: - document = inp.path_or_content - if inp.preprocessor_input_format is None: - log.error(f"Binary document format is not provided for {inp.preprocessor_input_id}, skipping it") + document = inp.data_element_path_or_content + if inp.data_element_format is None: + log.error(f"Binary document format is not provided for {inp.data_element_id}, skipping it") continue - if inp.preprocessor_input_format != PreprocessingDataFormat.pdf: + if inp.data_element_format != PreprocessingDataFormat.pdf: log.error( - f"Unsupported binary document type {inp.preprocessor_input_format} for {inp.preprocessor_input_id}, skipping it" + f"Unsupported binary document type {inp.data_element_format} for {inp.data_element_id}, skipping it" ) continue elif input_type == PreprocessingDataType.raw_text_document: - document = interleaved_content_as_str(inp.path_or_content) + document = interleaved_content_as_str(inp.data_element_path_or_content) else: - log.error(f"Unexpected preprocessor input type: {inp.preprocessor_input_type}") + log.error(f"Unexpected preprocessor input type: {input_type}") continue - if inp.preprocessor_input_format == PreprocessingDataFormat.pdf: + if inp.data_element_format == PreprocessingDataFormat.pdf: document = parse_pdf(document) - results.append(document) + new_result = PreprocessingDataElement( + data_element_id=inp.data_element_id, + data_element_type=PreprocessingDataType.raw_text_document, + data_element_format=PreprocessingDataFormat.txt, + data_element_path_or_content=document, + ) + results.append(new_result) return PreprocessorResponse( - success=True, preprocessor_output_type=PreprocessingDataType.raw_text_document, results=results + success=True, output_data_type=PreprocessingDataType.raw_text_document, results=results ) async def chain_preprocess( self, preprocessors: PreprocessorChain, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) @staticmethod - async def _resolve_input_type(preprocessor_input: PreprocessorInput) -> PreprocessingDataType: - if preprocessor_input.preprocessor_input_type is not None: - return preprocessor_input.preprocessor_input_type + def _resolve_input_type(preprocessor_input: PreprocessingDataElement) -> PreprocessingDataType: + if preprocessor_input.data_element_type is not None: + return preprocessor_input.data_element_type - if isinstance(preprocessor_input.path_or_content, URL): + if isinstance(preprocessor_input.data_element_path_or_content, URL): return PreprocessingDataType.document_uri - if InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(preprocessor_input.path_or_content): + if InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(preprocessor_input.data_element_path_or_content): return PreprocessingDataType.document_uri - if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf: + if preprocessor_input.data_element_format == PreprocessingDataFormat.pdf: return PreprocessingDataType.binary_document return PreprocessingDataType.raw_text_document @staticmethod - async def _fetch_document(preprocessor_input: PreprocessorInput) -> str | None: - if isinstance(preprocessor_input.path_or_content, str): - url = preprocessor_input.path_or_content + async def _fetch_document(preprocessor_input: PreprocessingDataElement) -> str | None: + if isinstance(preprocessor_input.data_element_path_or_content, str): + url = preprocessor_input.data_element_path_or_content if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url): log.error(f"Unexpected URL: {url}") return None - elif isinstance(preprocessor_input.path_or_content, URL): - url = preprocessor_input.path_or_content.uri + elif isinstance(preprocessor_input.data_element_path_or_content, URL): + url = preprocessor_input.data_element_path_or_content.uri else: log.error( - f"Unexpected type {type(preprocessor_input.path_or_content)} for input {preprocessor_input.path_or_content}, skipping this input." + f"Unexpected type {type(preprocessor_input.data_element_path_or_content)} for input {preprocessor_input.data_element_path_or_content}, skipping this input." ) return None @@ -134,4 +140,4 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): async with httpx.AsyncClient() as client: r = await client.get(url) - return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text + return r.content if preprocessor_input.data_element_format == PreprocessingDataFormat.pdf else r.text diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py index 3492a70c1..5c2641ea7 100644 --- a/llama_stack/providers/inline/preprocessing/docling/docling.py +++ b/llama_stack/providers/inline/preprocessing/docling/docling.py @@ -12,10 +12,11 @@ from docling_core.transforms.chunker.hybrid_chunker import HybridChunker from llama_stack.apis.common.content_types import URL from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, + PreprocessingDataFormat, PreprocessingDataType, Preprocessor, PreprocessorChain, - PreprocessorInput, PreprocessorOptions, PreprocessorResponse, ) @@ -51,38 +52,51 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate async def preprocess( self, preprocessor_id: str, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], options: Optional[PreprocessorOptions] = None, ) -> PreprocessorResponse: results = [] for inp in preprocessor_inputs: - if isinstance(inp.path_or_content, str): - url = inp.path_or_content - elif isinstance(inp.path_or_content, URL): - url = inp.path_or_content.uri + if isinstance(inp.data_element_path_or_content, str): + url = inp.data_element_path_or_content + elif isinstance(inp.data_element_path_or_content, URL): + url = inp.data_element_path_or_content.uri else: log.error( - f"Unexpected type {type(inp.path_or_content)} for input {inp.path_or_content}, skipping this input." + f"Unexpected type {type(inp.data_element_path_or_content)} for input {inp.data_element_path_or_content}, skipping this input." ) continue converted_document = self.converter.convert(url).document if self.config.chunk: result = self.chunker.chunk(converted_document) - results.extend([Chunk(content=chunk.text, metadata=chunk.meta) for chunk in result]) + for i, chunk in enumerate(result): + raw_chunk = Chunk(content=chunk.text, metadata=chunk.meta) + chunk_data_element = PreprocessingDataElement( + data_element_id=f"{inp.data_element_id}_chunk_{i}", + data_element_type=PreprocessingDataType.chunks, + data_element_format=PreprocessingDataFormat.txt, + data_element_path_or_content=raw_chunk, + ) + results.append(chunk_data_element) else: - result = converted_document.export_to_markdown() + result = PreprocessingDataElement( + data_element_id=inp.data_element_id, + data_element_type=PreprocessingDataType.raw_text_document, + data_element_format=PreprocessingDataFormat.txt, + data_element_path_or_content=converted_document.export_to_markdown(), + ) results.append(result) - preprocessor_output_type = ( + output_data_type = ( PreprocessingDataType.chunks if self.config.chunk else PreprocessingDataType.raw_text_document ) - return PreprocessorResponse(success=True, preprocessor_output_type=preprocessor_output_type, results=results) + return PreprocessorResponse(success=True, output_data_type=output_data_type, results=results) async def chain_preprocess( self, preprocessors: PreprocessorChain, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) diff --git a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py index 5aea70079..38ed18837 100644 --- a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py +++ b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py @@ -11,10 +11,11 @@ from llama_models.llama3.api import Tokenizer from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, + PreprocessingDataFormat, PreprocessingDataType, Preprocessor, PreprocessorChain, - PreprocessorInput, PreprocessorOptions, PreprocessorResponse, ) @@ -49,7 +50,7 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): async def preprocess( self, preprocessor_id: str, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], options: Optional[PreprocessorOptions] = None, ) -> PreprocessorResponse: chunks = [] @@ -58,16 +59,23 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): for inp in preprocessor_inputs: new_chunks = self.make_overlapped_chunks( - inp.preprocessor_input_id, inp.path_or_content, window_len, overlap_len + inp.data_element_id, inp.data_element_path_or_content, window_len, overlap_len ) - chunks.extend(new_chunks) + for i, chunk in enumerate(new_chunks): + new_chunk_data_element = PreprocessingDataElement( + data_element_id=f"{inp.data_element_id}_chunk_{i}", + data_element_type=PreprocessingDataType.chunks, + data_element_format=PreprocessingDataFormat.txt, + data_element_path_or_content=chunk, + ) + chunks.append(new_chunk_data_element) - return PreprocessorResponse(success=True, preprocessor_output_type=PreprocessingDataType.chunks, results=chunks) + return PreprocessorResponse(success=True, output_data_type=PreprocessingDataType.chunks, results=chunks) async def chain_preprocess( self, preprocessors: PreprocessorChain, - preprocessor_inputs: List[PreprocessorInput], + preprocessor_inputs: List[PreprocessingDataElement], ) -> PreprocessorResponse: return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs) diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a81b9f16..bda130811 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -20,11 +20,11 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.inference import Inference from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataElement, PreprocessingDataFormat, PreprocessingDataType, PreprocessorChain, PreprocessorChainElement, - PreprocessorInput, ) from llama_stack.apis.tools import ( RAGDocument, @@ -81,10 +81,6 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): preprocessor_chain: Optional[PreprocessorChain] = None, ) -> None: preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents] - preprocessor_chain = [ - PreprocessorChainElement(preprocessor_id="builtin::basic"), - PreprocessorChainElement(preprocessor_id="builtin::chunking"), - ] preprocessor_response = await self.preprocessing_api.chain_preprocess( preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN, preprocessor_inputs=preprocessor_inputs, @@ -94,9 +90,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): log.error("Preprocessor chain returned an error") return - if preprocessor_response.preprocessor_output_type != PreprocessingDataType.chunks: + if preprocessor_response.output_data_type != PreprocessingDataType.chunks: log.error( - f"Preprocessor chain returned {preprocessor_response.preprocessor_output_type} instead of {PreprocessingDataType.chunks}" + f"Preprocessor chain returned {preprocessor_response.output_data_type} instead of {PreprocessingDataType.chunks}" ) return @@ -105,8 +101,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): log.error("No chunks returned by the preprocessor chain") return + actual_chunks = [chunk.data_element_path_or_content for chunk in chunks] await self.vector_io_api.insert_chunks( - chunks=chunks, + chunks=actual_chunks, vector_db_id=vector_db_id, ) @@ -220,14 +217,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ) @staticmethod - def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput: + def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessingDataElement: if document.mime_type == "application/pdf": - preprocessor_input_format = PreprocessingDataFormat.pdf + data_element_format = PreprocessingDataFormat.pdf else: - preprocessor_input_format = None + data_element_format = None - return PreprocessorInput( - preprocessor_input_id=document.document_id, - preprocessor_input_format=preprocessor_input_format, - path_or_content=document.content, + return PreprocessingDataElement( + data_element_id=document.document_id, + data_element_format=data_element_format, + data_element_path_or_content=document.content, )