Initial implementation of RAG operator using the preprocessing endpoint.

ilya-kolchinsky 2025-03-05 13:43:26 +01:00
parent c2bd31eb5c
commit 16764a2f06
8 changed files with 74 additions and 37 deletions

@@ -23,10 +23,22 @@ class PreprocessingDataType(Enum):
     chunks = "chunks"
 
 
+class PreprocessingDataFormat(Enum):
+    pdf = "pdf"
+    docx = "docx"
+    xlsx = "xlsx"
+    pptx = "pptx"
+    md = "md"
+    json = "json"
+    html = "html"
+    csv = "csv"
+
+
 @json_schema_type
 class PreprocessingInput(BaseModel):
     preprocessor_input_id: str
     preprocessor_input_type: Optional[PreprocessingDataType] = None
+    preprocessor_input_format: Optional[PreprocessingDataFormat] = None
     path_or_content: str | InterleavedContent | URL
@@ -52,5 +64,5 @@ class Preprocessing(Protocol):
         self,
         preprocessor_id: str,
         preprocessor_inputs: List[PreprocessingInput],
-        options: PreprocessorOptions,
+        options: Optional[PreprocessorOptions] = None,
     ) -> PreprocessingResponse: ...
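
For orientation, a minimal caller-side sketch of the extended schema follows. The preprocessing_api handle, input id, and URL are placeholders; only the field names, the Optional types, and the now-optional options argument come from the API above, and "builtin::basic" is the preprocessor id used later in this commit.

    # Hedged usage sketch (inside an async function); preprocessing_api is an
    # implementation of the Preprocessing protocol obtained from the stack.
    from llama_stack.apis.preprocessing import PreprocessingDataFormat, PreprocessingInput

    inputs = [
        PreprocessingInput(
            preprocessor_input_id="report-1",  # placeholder id
            preprocessor_input_format=PreprocessingDataFormat.pdf,  # new optional format hint
            path_or_content="https://example.com/report.pdf",  # placeholder URL
        )
    ]
    response = await preprocessing_api.preprocess(
        preprocessor_id="builtin::basic",
        preprocessor_inputs=inputs,
        # options can now be omitted; providers fall back to their config defaults
    )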

@@ -502,9 +502,10 @@ class PreprocessingRouter(Preprocessing):
         self,
         preprocessor_id: str,
         preprocessor_inputs: List[PreprocessingInput],
-        options: PreprocessorOptions,
+        options: Optional[PreprocessorOptions] = None,
     ) -> PreprocessingResponse:
         return await self.routing_table.get_provider_impl(preprocessor_id).preprocess(
+            preprocessor_id=preprocessor_id,
             preprocessor_inputs=preprocessor_inputs,
             options=options,
         )

@@ -5,13 +5,14 @@
 # the root directory of this source tree.
 
 import logging
 import re
-from typing import List
+from typing import List, Optional
 
 import httpx
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.preprocessing import (
     Preprocessing,
+    PreprocessingDataFormat,
     PreprocessingDataType,
     PreprocessingInput,
     PreprocessingResponse,
@@ -54,22 +55,26 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
         self,
         preprocessor_id: str,
         preprocessor_inputs: List[PreprocessingInput],
-        options: PreprocessorOptions,
+        options: Optional[PreprocessorOptions] = None,
     ) -> PreprocessingResponse:
         results = []
 
         for inp in preprocessor_inputs:
-            is_pdf = options["binary_document_type"] == "pdf"
-            input_type = self._resolve_input_type(inp, is_pdf)
+            input_type = self._resolve_input_type(inp)
 
             if input_type == PreprocessingDataType.document_uri:
-                document = await self._fetch_document(inp, is_pdf)
+                document = await self._fetch_document(inp)
                 if document is None:
                     continue
             elif input_type == PreprocessingDataType.binary_document:
                 document = inp.path_or_content
-                if not is_pdf:
-                    log.error(f"Unsupported binary document type: {options['binary_document_type']}")
+                if inp.preprocessor_input_format is None:
+                    log.error(f"Binary document format is not provided for {inp.preprocessor_input_id}, skipping it")
+                    continue
+                if inp.preprocessor_input_format != PreprocessingDataFormat.pdf:
+                    log.error(
+                        f"Unsupported binary document type {inp.preprocessor_input_format} for {inp.preprocessor_input_id}, skipping it"
+                    )
                     continue
             elif input_type == PreprocessingDataType.raw_text_document:
                 document = interleaved_content_as_str(inp.path_or_content)
@@ -77,7 +82,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
                 log.error(f"Unexpected preprocessor input type: {inp.preprocessor_input_type}")
                 continue
 
-            if is_pdf:
+            if inp.preprocessor_input_format == PreprocessingDataFormat.pdf:
                 document = parse_pdf(document)
 
             results.append(document)
@@ -85,7 +90,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
         return PreprocessingResponse(status=True, results=results)
 
     @staticmethod
-    async def _resolve_input_type(preprocessor_input: PreprocessingInput, is_pdf: bool) -> PreprocessingDataType:
+    async def _resolve_input_type(preprocessor_input: PreprocessingInput) -> PreprocessingDataType:
         if preprocessor_input.preprocessor_input_type is not None:
             return preprocessor_input.preprocessor_input_type
@@ -93,13 +98,13 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
             return PreprocessingDataType.document_uri
 
         if InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(preprocessor_input.path_or_content):
             return PreprocessingDataType.document_uri
 
-        if is_pdf:
+        if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf:
             return PreprocessingDataType.binary_document
 
         return PreprocessingDataType.raw_text_document
 
     @staticmethod
-    async def _fetch_document(preprocessor_input: PreprocessingInput, is_pdf: bool) -> str | None:
+    async def _fetch_document(preprocessor_input: PreprocessingInput) -> str | None:
         if isinstance(preprocessor_input.path_or_content, str):
             url = preprocessor_input.path_or_content
             if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
@@ -118,4 +123,9 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
         async with httpx.AsyncClient() as client:
             r = await client.get(url)
-            return r.content if is_pdf else r.text
+            return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text
+
+    @staticmethod
+    def is_pdf(preprocessor_input: PreprocessingInput):
+        return
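
With the option-driven is_pdf flag removed, the format hint now travels with each input rather than applying to the whole batch. A small illustrative sketch (ids and URL are made up) of two common shapes the loop above distinguishes:

    from llama_stack.apis.preprocessing import PreprocessingDataFormat, PreprocessingInput

    # Resolved to document_uri; the pdf hint makes _fetch_document return raw
    # bytes (r.content) and parse_pdf is applied afterwards.
    pdf_by_url = PreprocessingInput(
        preprocessor_input_id="pdf-by-url",
        preprocessor_input_format=PreprocessingDataFormat.pdf,
        path_or_content="https://example.com/report.pdf",
    )

    # Resolved to raw_text_document; not a URL and no format hint, so it is
    # passed through via interleaved_content_as_str without PDF parsing.
    inline_text = PreprocessingInput(
        preprocessor_input_id="inline-text",
        path_or_content="plain text that will be converted as-is",
    )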

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 import logging
-from typing import List
+from typing import List, Optional
 
 from docling.document_converter import DocumentConverter
 from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
@@ -51,7 +51,7 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
         self,
         preprocessor_id: str,
         preprocessor_inputs: List[PreprocessingInput],
-        options: PreprocessorOptions,
+        options: Optional[PreprocessorOptions] = None,
     ) -> PreprocessingResponse:
         results = []

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import logging
 from enum import Enum
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from llama_models.llama3.api import Tokenizer
@@ -49,7 +49,7 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
         self,
         preprocessor_id: str,
         preprocessor_inputs: List[PreprocessingInput],
-        options: PreprocessorOptions,
+        options: Optional[PreprocessorOptions] = None,
     ) -> PreprocessingResponse:
         chunks = []
@@ -64,9 +64,11 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
         return PreprocessingResponse(status=True, results=chunks)
 
     def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
-        window_len = options.get(str(SimpleChunkingOptions.chunk_size_in_tokens), self.config.chunk_size_in_tokens)
-        chunk_overlap_ratio = options.get(
+        window_len = (options or {}).get(
+            str(SimpleChunkingOptions.chunk_size_in_tokens), self.config.chunk_size_in_tokens
+        )
+        chunk_overlap_ratio = (options or {}).get(
             str(SimpleChunkingOptions.chunk_overlap_ratio), self.config.chunk_overlap_ratio
         )
         overlap_len = window_len // chunk_overlap_ratio
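
The (options or {}) guard only makes the lookups tolerant of the now-optional options argument. In isolation the pattern behaves as in this small sketch (the function name and the 512 default are placeholders standing in for the provider config):

    def resolve_chunk_size(options, config_default=512):
        # Fall back to the configured default when options is None or missing the key.
        return (options or {}).get("chunk_size_in_tokens", config_default)

    assert resolve_chunk_size(None) == 512
    assert resolve_chunk_size({}) == 512
    assert resolve_chunk_size({"chunk_size_in_tokens": 256}) == 256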

@@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
 async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
     from .memory import MemoryToolRuntimeImpl
 
-    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.preprocessing])
     await impl.initialize()
     return impl

@@ -18,6 +18,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import Inference
+from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -30,10 +31,6 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import (
-    content_from_doc,
-    make_overlapped_chunks,
-)
 
 from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query
@@ -51,10 +48,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         config: RagToolRuntimeConfig,
         vector_io_api: VectorIO,
         inference_api: Inference,
+        preprocessing_api: Preprocessing,
     ):
         self.config = config
         self.vector_io_api = vector_io_api
         self.inference_api = inference_api
+        self.preprocessing_api = preprocessing_api
 
     async def initialize(self):
         pass
@@ -68,17 +67,17 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
-        chunks = []
-        for doc in documents:
-            content = await content_from_doc(doc)
-            chunks.extend(
-                make_overlapped_chunks(
-                    doc.document_id,
-                    content,
-                    chunk_size_in_tokens,
-                    chunk_size_in_tokens // 4,
-                )
-            )
+        preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
+
+        conversion_response = await self.preprocessing_api.preprocess(
+            preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
+        )
+        converted_inputs = conversion_response.results
+
+        chunking_response = await self.preprocessing_api.preprocess(
+            preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
+        )
+        chunks = chunking_response.results
 
         if not chunks:
             return
@@ -196,3 +195,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             content=result.content,
             metadata=result.metadata,
         )
+
+    @staticmethod
+    def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
+        if document.mime_type == "application/pdf":
+            preprocessor_input_format = PreprocessingDataFormat.pdf
+        else:
+            preprocessor_input_format = None
+
+        return PreprocessingInput(
+            preprocessor_input_id=document.document_id,
+            preprocessor_input_format=preprocessor_input_format,
+            path_or_content=document.content,
+        )
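
Taken together, document ingestion in the RAG tool now flows through two preprocessing calls instead of the local chunking helpers. A hedged end-to-end sketch (the runtime handle, document id, vector DB id, and URL are placeholders; the field names come from the code above):

    from llama_stack.apis.tools import RAGDocument

    # Inside an async function; rag_tool is a MemoryToolRuntimeImpl instance.
    doc = RAGDocument(
        document_id="report-1",  # becomes preprocessor_input_id
        content="https://example.com/report.pdf",  # becomes path_or_content
        mime_type="application/pdf",  # mapped to PreprocessingDataFormat.pdf
        metadata={},
    )
    await rag_tool.insert(documents=[doc], vector_db_id="my-vector-db")
    # Internally: RAGDocument -> PreprocessingInput -> "builtin::basic" (fetch/convert)
    # -> "builtin::chunking" (split into chunks) -> vector_io insertion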

@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
         ],
         module="llama_stack.providers.inline.tool_runtime.rag",
         config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
-        api_dependencies=[Api.vector_io, Api.inference],
+        api_dependencies=[Api.vector_io, Api.inference, Api.preprocessing],
     ),
     InlineProviderSpec(
         api=Api.tool_runtime,