Added draft implementation of built-in preprocessing for RAG.

This commit is contained in:
ilya-kolchinsky 2025-03-04 15:22:29 +01:00
parent 5014de434e
commit 1a6e71c61f
9 changed files with 299 additions and 4 deletions

View file

@ -26,8 +26,8 @@ class PreprocessingDataType(Enum):
@json_schema_type
class PreprocessingInput(BaseModel):
preprocessor_input_id: str
preprocessor_input_type: Optional[PreprocessingDataType]
path_or_content: str | URL
preprocessor_input_type: Optional[PreprocessingDataType] = None
path_or_content: str | InterleavedContent | URL
PreprocessorOptions = Dict[str, Any]
@ -36,7 +36,7 @@ PreprocessorOptions = Dict[str, Any]
@json_schema_type
class PreprocessingResponse(BaseModel):
status: bool
results: Optional[List[str | InterleavedContent | Chunk]]
results: Optional[List[str | InterleavedContent | Chunk]] = None
class PreprocessorStore(Protocol):

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import InlineBasicPreprocessorConfig
async def get_provider_impl(
config: InlineBasicPreprocessorConfig,
_deps,
):
from .basic import InclineBasicPreprocessorImpl
impl = InclineBasicPreprocessorImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
from typing import List
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataType,
PreprocessingInput,
PreprocessingResponse,
Preprocessor,
PreprocessorOptions,
)
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import content_from_data, parse_pdf
log = logging.getLogger(__name__)
class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor can either receive documents (text or binary) or document URIs
INPUT_TYPES = [
PreprocessingDataType.binary_document,
PreprocessingDataType.raw_text_document,
PreprocessingDataType.document_uri,
]
# this preprocessor optionally retrieves the documents and converts them into plain text
OUTPUT_TYPES = [PreprocessingDataType.raw_text_document]
URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)")
def __init__(self, config: InlineBasicPreprocessorConfig) -> None:
self.config = config
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def register_preprocessor(self, preprocessor: Preprocessor) -> None: ...
async def unregister_preprocessor(self, preprocessor_id: str) -> None: ...
async def preprocess(
self,
preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput],
options: PreprocessorOptions,
) -> PreprocessingResponse:
results = []
for inp in preprocessor_inputs:
is_pdf = options["binary_document_type"] == "pdf"
input_type = self._resolve_input_type(inp, is_pdf)
if input_type == PreprocessingDataType.document_uri:
document = await self._fetch_document(inp, is_pdf)
if document is None:
continue
elif input_type == PreprocessingDataType.binary_document:
document = inp.path_or_content
if not is_pdf:
log.error(f"Unsupported binary document type: {options['binary_document_type']}")
continue
elif input_type == PreprocessingDataType.raw_text_document:
document = interleaved_content_as_str(inp.path_or_content)
else:
log.error(f"Unexpected preprocessor input type: {inp.preprocessor_input_type}")
continue
if is_pdf:
document = parse_pdf(document)
results.append(document)
return PreprocessingResponse(status=True, results=results)
@staticmethod
async def _resolve_input_type(preprocessor_input: PreprocessingInput, is_pdf: bool) -> PreprocessingDataType:
if preprocessor_input.preprocessor_input_type is not None:
return preprocessor_input.preprocessor_input_type
if isinstance(preprocessor_input.path_or_content, URL):
return PreprocessingDataType.document_uri
if InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(preprocessor_input.path_or_content):
return PreprocessingDataType.document_uri
if is_pdf:
return PreprocessingDataType.binary_document
return PreprocessingDataType.raw_text_document
@staticmethod
async def _fetch_document(preprocessor_input: PreprocessingInput, is_pdf: bool) -> str | None:
if isinstance(preprocessor_input.path_or_content, str):
url = preprocessor_input.path_or_content
if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
log.error(f"Unexpected URL: {url}")
return None
elif isinstance(preprocessor_input.path_or_content, URL):
url = preprocessor_input.path_or_content.uri
else:
log.error(
f"Unexpected type {type(preprocessor_input.path_or_content)} for input {preprocessor_input.path_or_content}, skipping this input."
)
return None
if url.startswith("data:"):
return content_from_data(url)
async with httpx.AsyncClient() as client:
r = await client.get(url)
return r.content if is_pdf else r.text

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class InlineBasicPreprocessorConfig(BaseModel): ...

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import List
from docling.document_converter import DocumentConverter
@ -21,6 +22,8 @@ from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
from llama_stack.providers.inline.preprocessing.docling import InlineDoclingConfig
log = logging.getLogger(__name__)
class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor receives URLs / paths to documents as input
@ -58,7 +61,10 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
elif isinstance(inp.path_or_content, URL):
url = inp.path_or_content.uri
else:
raise ValueError(f"Unexpected type {type(inp.path_or_content)} for input {inp.path_or_content}")
log.error(
f"Unexpected type {type(inp.path_or_content)} for input {inp.path_or_content}, skipping this input."
)
continue
converted_document = self.converter.convert(url).document
if self.config.chunk:

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import InclineSimpleChunkingConfig
async def get_provider_impl(
config: InclineSimpleChunkingConfig,
_deps,
):
from .simple_chunking import InclineSimpleChunkingImpl
impl = InclineSimpleChunkingImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class InclineSimpleChunkingConfig(BaseModel):
chunk_size_in_tokens: int = 512
chunk_overlap_ratio: int = 4

View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from enum import Enum
from typing import List, Tuple
from llama_models.llama3.api import Tokenizer
from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataType,
PreprocessingInput,
PreprocessingResponse,
Preprocessor,
PreprocessorOptions,
)
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
from llama_stack.providers.inline.preprocessing.simple_chunking import InclineSimpleChunkingConfig
log = logging.getLogger(__name__)
class SimpleChunkingOptions(Enum):
chunk_size_in_tokens = "chunk_size_in_tokens"
chunk_overlap_ratio = "chunk_overlap_ratio"
class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
# this preprocessor receives plain text and returns chunks
INPUT_TYPES = [PreprocessingDataType.raw_text_document]
OUTPUT_TYPES = [PreprocessingDataType.chunks]
def __init__(self, config: InclineSimpleChunkingConfig) -> None:
self.config = config
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def register_preprocessor(self, preprocessor: Preprocessor) -> None: ...
async def unregister_preprocessor(self, preprocessor_id: str) -> None: ...
async def preprocess(
self,
preprocessor_id: str,
preprocessor_inputs: List[PreprocessingInput],
options: PreprocessorOptions,
) -> PreprocessingResponse:
chunks = []
window_len, overlap_len = self._resolve_chunk_size_params(options)
for inp in preprocessor_inputs:
new_chunks = self.make_overlapped_chunks(
inp.preprocessor_input_id, inp.path_or_content, window_len, overlap_len
)
chunks.extend(new_chunks)
return PreprocessingResponse(status=True, results=chunks)
def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
window_len = options.get(str(SimpleChunkingOptions.chunk_size_in_tokens), self.config.chunk_size_in_tokens)
chunk_overlap_ratio = options.get(
str(SimpleChunkingOptions.chunk_overlap_ratio), self.config.chunk_overlap_ratio
)
overlap_len = window_len // chunk_overlap_ratio
return window_len, overlap_len
@staticmethod
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata={
"token_count": len(toks),
"document_id": document_id,
},
)
)
return chunks

View file

@ -23,4 +23,20 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.preprocessing.docling.InlineDoclingConfig",
api_dependencies=[],
),
InlineProviderSpec(
api=Api.preprocessing,
provider_type="inline::basic",
pip_packages=["httpx", "pypdf"],
module="llama_stack.providers.inline.preprocessing.basic",
config_class="llama_stack.providers.inline.preprocessing.basic.InlineBasicPreprocessorConfig",
api_dependencies=[],
),
InlineProviderSpec(
api=Api.preprocessing,
provider_type="inline::simple_chunking",
pip_packages=[],
module="llama_stack.providers.inline.preprocessing.simple_chunking",
config_class="llama_stack.providers.inline.preprocessing.simple_chunking.InclineSimpleChunkingConfig",
api_dependencies=[],
),
]