diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py index 1e2f0fa49..9e472cc1d 100644 --- a/llama_stack/apis/preprocessing/preprocessing.py +++ b/llama_stack/apis/preprocessing/preprocessing.py @@ -23,10 +23,22 @@ class PreprocessingDataType(Enum): chunks = "chunks" +class PreprocessingDataFormat(Enum): + pdf = "pdf" + docx = "docx" + xlsx = "xlsx" + pptx = "pptx" + md = "md" + json = "json" + html = "html" + csv = "csv" + + @json_schema_type class PreprocessingInput(BaseModel): preprocessor_input_id: str preprocessor_input_type: Optional[PreprocessingDataType] = None + preprocessor_input_format: Optional[PreprocessingDataFormat] = None path_or_content: str | InterleavedContent | URL @@ -52,5 +64,5 @@ class Preprocessing(Protocol): self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingInput], - options: PreprocessorOptions, + options: Optional[PreprocessorOptions] = None, ) -> PreprocessingResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 561a67e27..19a2e6249 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -502,9 +502,10 @@ class PreprocessingRouter(Preprocessing): self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingInput], - options: PreprocessorOptions, + options: Optional[PreprocessorOptions] = None, ) -> PreprocessingResponse: return await self.routing_table.get_provider_impl(preprocessor_id).preprocess( + preprocessor_id=preprocessor_id, preprocessor_inputs=preprocessor_inputs, options=options, ) diff --git a/llama_stack/providers/inline/preprocessing/basic/basic.py b/llama_stack/providers/inline/preprocessing/basic/basic.py index 539dbe586..df2d138cd 100644 --- a/llama_stack/providers/inline/preprocessing/basic/basic.py +++ b/llama_stack/providers/inline/preprocessing/basic/basic.py @@ -5,13 +5,14 @@ # the root directory of this source tree. 
import logging import re -from typing import List +from typing import List, Optional import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.preprocessing import ( Preprocessing, + PreprocessingDataFormat, PreprocessingDataType, PreprocessingInput, PreprocessingResponse, @@ -54,22 +55,26 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingInput], - options: PreprocessorOptions, + options: Optional[PreprocessorOptions] = None, ) -> PreprocessingResponse: results = [] for inp in preprocessor_inputs: - is_pdf = options["binary_document_type"] == "pdf" - input_type = self._resolve_input_type(inp, is_pdf) + input_type = self._resolve_input_type(inp) if input_type == PreprocessingDataType.document_uri: - document = await self._fetch_document(inp, is_pdf) + document = await self._fetch_document(inp) if document is None: continue elif input_type == PreprocessingDataType.binary_document: document = inp.path_or_content - if not is_pdf: - log.error(f"Unsupported binary document type: {options['binary_document_type']}") + if inp.preprocessor_input_format is None: + log.error(f"Binary document format is not provided for {inp.preprocessor_input_id}, skipping it") + continue + if inp.preprocessor_input_format != PreprocessingDataFormat.pdf: + log.error( + f"Unsupported binary document type {inp.preprocessor_input_format} for {inp.preprocessor_input_id}, skipping it" + ) continue elif input_type == PreprocessingDataType.raw_text_document: document = interleaved_content_as_str(inp.path_or_content) @@ -77,7 +82,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): log.error(f"Unexpected preprocessor input type: {inp.preprocessor_input_type}") continue - if is_pdf: + if inp.preprocessor_input_format == PreprocessingDataFormat.pdf: document = parse_pdf(document) results.append(document) @@ -85,7 +90,7 @@ class 
InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): return PreprocessingResponse(status=True, results=results) @staticmethod -    async def _resolve_input_type(preprocessor_input: PreprocessingInput, is_pdf: bool) -> PreprocessingDataType: +    def _resolve_input_type(preprocessor_input: PreprocessingInput) -> PreprocessingDataType: if preprocessor_input.preprocessor_input_type is not None: return preprocessor_input.preprocessor_input_type @@ -93,13 +98,13 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): return PreprocessingDataType.document_uri if InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(preprocessor_input.path_or_content): return PreprocessingDataType.document_uri -        if is_pdf: +        if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf: return PreprocessingDataType.binary_document return PreprocessingDataType.raw_text_document @staticmethod -    async def _fetch_document(preprocessor_input: PreprocessingInput, is_pdf: bool) -> str | None: +    async def _fetch_document(preprocessor_input: PreprocessingInput) -> str | None: if isinstance(preprocessor_input.path_or_content, str): url = preprocessor_input.path_or_content if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url): @@ -118,4 +123,5 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): async with httpx.AsyncClient() as client: r = await client.get(url) -            return r.content if is_pdf else r.text + +            return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py index 34df1ea4a..90b5fd912 100644 --- a/llama_stack/providers/inline/preprocessing/docling/docling.py +++ 
b/llama_stack/providers/inline/preprocessing/docling/docling.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import logging -from typing import List +from typing import List, Optional from docling.document_converter import DocumentConverter from docling_core.transforms.chunker.hybrid_chunker import HybridChunker @@ -51,7 +51,7 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingInput], - options: PreprocessorOptions, + options: Optional[PreprocessorOptions] = None, ) -> PreprocessingResponse: results = [] diff --git a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py index ea44413de..542178db3 100644 --- a/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py +++ b/llama_stack/providers/inline/preprocessing/simple_chunking/simple_chunking.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import logging from enum import Enum -from typing import List, Tuple +from typing import List, Optional, Tuple from llama_models.llama3.api import Tokenizer @@ -49,7 +49,7 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): self, preprocessor_id: str, preprocessor_inputs: List[PreprocessingInput], - options: PreprocessorOptions, + options: Optional[PreprocessorOptions] = None, ) -> PreprocessingResponse: chunks = [] @@ -64,9 +64,11 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate): return PreprocessingResponse(status=True, results=chunks) def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]: - window_len = options.get(str(SimpleChunkingOptions.chunk_size_in_tokens), self.config.chunk_size_in_tokens) + window_len = (options or {}).get( + str(SimpleChunkingOptions.chunk_size_in_tokens), self.config.chunk_size_in_tokens + ) - chunk_overlap_ratio = options.get( + chunk_overlap_ratio = (options or {}).get( str(SimpleChunkingOptions.chunk_overlap_ratio), self.config.chunk_overlap_ratio ) overlap_len = window_len // chunk_overlap_ratio diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index 15118c9df..292e1a4e7 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]): from .memory import MemoryToolRuntimeImpl - impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.preprocessing]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 4b3f7d9e7..a0e0573a7 100644 --- 
a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -18,6 +18,7 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import Inference +from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -30,10 +31,6 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import ToolsProtocolPrivate -from llama_stack.providers.utils.memory.vector_store import ( - content_from_doc, - make_overlapped_chunks, -) from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query @@ -51,10 +48,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, + preprocessing_api: Preprocessing, ): self.config = config self.vector_io_api = vector_io_api self.inference_api = inference_api + self.preprocessing_api = preprocessing_api async def initialize(self): pass @@ -68,17 +67,17 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - chunks = [] - for doc in documents: - content = await content_from_doc(doc) - chunks.extend( - make_overlapped_chunks( - doc.document_id, - content, - chunk_size_in_tokens, - chunk_size_in_tokens // 4, - ) - ) + preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents] + + conversion_response = await self.preprocessing_api.preprocess( + preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs + ) + converted_inputs = conversion_response.results + + chunking_response = await self.preprocessing_api.preprocess( + preprocessor_id="builtin::chunking", 
preprocessor_inputs=converted_inputs + ) + chunks = chunking_response.results if not chunks: return @@ -196,3 +195,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): content=result.content, metadata=result.metadata, ) + + @staticmethod + def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput: + if document.mime_type == "application/pdf": + preprocessor_input_format = PreprocessingDataFormat.pdf + else: + preprocessor_input_format = None + + return PreprocessingInput( + preprocessor_input_id=document.document_id, + preprocessor_input_format=preprocessor_input_format, + path_or_content=document.content, + ) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 95ea2dcf9..a08df4c90 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.inference], + api_dependencies=[Api.vector_io, Api.inference, Api.preprocessing], ), InlineProviderSpec( api=Api.tool_runtime,