mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
Added a draft implementation of the preprocessor chain.
This commit is contained in:
parent
16764a2f06
commit
b981181b25
7 changed files with 180 additions and 46 deletions
|
@ -35,7 +35,7 @@ class PreprocessingDataFormat(Enum):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PreprocessingInput(BaseModel):
|
class PreprocessorInput(BaseModel):
|
||||||
preprocessor_input_id: str
|
preprocessor_input_id: str
|
||||||
preprocessor_input_type: Optional[PreprocessingDataType] = None
|
preprocessor_input_type: Optional[PreprocessingDataType] = None
|
||||||
preprocessor_input_format: Optional[PreprocessingDataFormat] = None
|
preprocessor_input_format: Optional[PreprocessingDataFormat] = None
|
||||||
|
@ -46,7 +46,16 @@ PreprocessorOptions = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PreprocessingResponse(BaseModel):
|
class PreprocessorChainElement(BaseModel):
|
||||||
|
preprocessor_id: str
|
||||||
|
options: Optional[PreprocessorOptions] = None
|
||||||
|
|
||||||
|
|
||||||
|
PreprocessorChain = List[PreprocessorChainElement]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PreprocessorResponse(BaseModel):
|
||||||
status: bool
|
status: bool
|
||||||
results: Optional[List[str | InterleavedContent | Chunk]] = None
|
results: Optional[List[str | InterleavedContent | Chunk]] = None
|
||||||
|
|
||||||
|
@ -59,10 +68,21 @@ class PreprocessorStore(Protocol):
|
||||||
class Preprocessing(Protocol):
|
class Preprocessing(Protocol):
|
||||||
preprocessor_store: PreprocessorStore
|
preprocessor_store: PreprocessorStore
|
||||||
|
|
||||||
|
input_types: List[PreprocessingDataType]
|
||||||
|
output_types: List[PreprocessingDataType]
|
||||||
|
|
||||||
@webmethod(route="/preprocess", method="POST")
|
@webmethod(route="/preprocess", method="POST")
|
||||||
async def preprocess(
|
async def preprocess(
|
||||||
self,
|
self,
|
||||||
preprocessor_id: str,
|
preprocessor_id: str,
|
||||||
preprocessor_inputs: List[PreprocessingInput],
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
options: Optional[PreprocessorOptions] = None,
|
options: Optional[PreprocessorOptions] = None,
|
||||||
) -> PreprocessingResponse: ...
|
) -> PreprocessorResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/chain_preprocess", method="POST")
|
||||||
|
async def chain_preprocess(
|
||||||
|
self,
|
||||||
|
preprocessors: PreprocessorChain,
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: Optional[bool] = False,
|
||||||
|
) -> PreprocessorResponse: ...
|
||||||
|
|
|
@ -34,7 +34,13 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingInput, PreprocessingResponse, PreprocessorOptions
|
from llama_stack.apis.preprocessing import (
|
||||||
|
Preprocessing,
|
||||||
|
PreprocessorChain,
|
||||||
|
PreprocessorInput,
|
||||||
|
PreprocessorOptions,
|
||||||
|
PreprocessorResponse,
|
||||||
|
)
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.scoring import (
|
from llama_stack.apis.scoring import (
|
||||||
ScoreBatchResponse,
|
ScoreBatchResponse,
|
||||||
|
@ -52,6 +58,7 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
|
from llama_stack.distribution.utils.chain import execute_preprocessor_chain
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||||
|
|
||||||
|
@ -501,11 +508,20 @@ class PreprocessingRouter(Preprocessing):
|
||||||
async def preprocess(
|
async def preprocess(
|
||||||
self,
|
self,
|
||||||
preprocessor_id: str,
|
preprocessor_id: str,
|
||||||
preprocessor_inputs: List[PreprocessingInput],
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
options: Optional[PreprocessorOptions] = None,
|
options: Optional[PreprocessorOptions] = None,
|
||||||
) -> PreprocessingResponse:
|
) -> PreprocessorResponse:
|
||||||
return await self.routing_table.get_provider_impl(preprocessor_id).preprocess(
|
return await self.routing_table.get_provider_impl(preprocessor_id).preprocess(
|
||||||
preprocessor_id=preprocessor_id,
|
preprocessor_id=preprocessor_id,
|
||||||
preprocessor_inputs=preprocessor_inputs,
|
preprocessor_inputs=preprocessor_inputs,
|
||||||
options=options,
|
options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def chain_preprocess(
|
||||||
|
self,
|
||||||
|
preprocessors: PreprocessorChain,
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: Optional[bool] = False,
|
||||||
|
) -> PreprocessorResponse:
|
||||||
|
preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors]
|
||||||
|
return await execute_preprocessor_chain(preprocessors, preprocessor_impls, preprocessor_inputs, is_rag_chain)
|
||||||
|
|
71
llama_stack/distribution/utils/chain.py
Normal file
71
llama_stack/distribution/utils/chain.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import logging
|
||||||
|
from itertools import pairwise
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_stack.apis.preprocessing import (
|
||||||
|
Preprocessing,
|
||||||
|
PreprocessingDataType,
|
||||||
|
PreprocessorChain,
|
||||||
|
PreprocessorInput,
|
||||||
|
PreprocessorResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chain(chain_impls: List[Preprocessing], is_rag_chain: bool) -> bool:
|
||||||
|
if len(chain_impls) == 0:
|
||||||
|
log.error("Empty preprocessing chain was provided")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_rag_chain and PreprocessingDataType.chunks not in chain_impls[-1].output_types:
|
||||||
|
log.error(
|
||||||
|
f"RAG preprocessing chain must end with a chunk-producing preprocessor, but the last preprocessor in the provided chain only supports {chain_impls[-1].output_types}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
for current_preprocessor, next_preprocessor in pairwise(chain_impls):
|
||||||
|
current_output_types = current_preprocessor.output_types
|
||||||
|
next_input_types = next_preprocessor.input_types
|
||||||
|
|
||||||
|
if len(list(set(current_output_types) & set(next_input_types))) == 0:
|
||||||
|
log.error(
|
||||||
|
f"Incompatible input ({current_output_types}) and output({next_input_types}) preprocessor data types"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_preprocessor_chain(
|
||||||
|
preprocessor_chain: PreprocessorChain,
|
||||||
|
preprocessor_chain_impls: List[Preprocessing],
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: bool,
|
||||||
|
) -> PreprocessorResponse:
|
||||||
|
if not validate_chain(preprocessor_chain_impls, is_rag_chain):
|
||||||
|
return PreprocessorResponse(status=False, results=[])
|
||||||
|
|
||||||
|
current_inputs = preprocessor_inputs
|
||||||
|
current_outputs = []
|
||||||
|
|
||||||
|
# TODO: replace with a parallel implementation
|
||||||
|
for i, current_params in enumerate(preprocessor_chain):
|
||||||
|
current_impl = preprocessor_chain_impls[i]
|
||||||
|
response = await current_impl.preprocess(
|
||||||
|
preprocessor_id=current_params.preprocessor_id,
|
||||||
|
preprocessor_inputs=current_inputs,
|
||||||
|
options=current_params.options,
|
||||||
|
)
|
||||||
|
if not response.status:
|
||||||
|
log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
|
||||||
|
return PreprocessorResponse(status=False, results=[])
|
||||||
|
current_outputs = response.results
|
||||||
|
current_inputs = current_outputs
|
||||||
|
|
||||||
|
return PreprocessorResponse(status=True, results=current_outputs)
|
|
@ -14,10 +14,11 @@ from llama_stack.apis.preprocessing import (
|
||||||
Preprocessing,
|
Preprocessing,
|
||||||
PreprocessingDataFormat,
|
PreprocessingDataFormat,
|
||||||
PreprocessingDataType,
|
PreprocessingDataType,
|
||||||
PreprocessingInput,
|
|
||||||
PreprocessingResponse,
|
|
||||||
Preprocessor,
|
Preprocessor,
|
||||||
|
PreprocessorChain,
|
||||||
|
PreprocessorInput,
|
||||||
PreprocessorOptions,
|
PreprocessorOptions,
|
||||||
|
PreprocessorResponse,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
||||||
from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig
|
from llama_stack.providers.inline.preprocessing.basic.config import InlineBasicPreprocessorConfig
|
||||||
|
@ -29,14 +30,14 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
# this preprocessor can either receive documents (text or binary) or document URIs
|
# this preprocessor can either receive documents (text or binary) or document URIs
|
||||||
INPUT_TYPES = [
|
input_types = [
|
||||||
PreprocessingDataType.binary_document,
|
PreprocessingDataType.binary_document,
|
||||||
PreprocessingDataType.raw_text_document,
|
PreprocessingDataType.raw_text_document,
|
||||||
PreprocessingDataType.document_uri,
|
PreprocessingDataType.document_uri,
|
||||||
]
|
]
|
||||||
|
|
||||||
# this preprocessor optionally retrieves the documents and converts them into plain text
|
# this preprocessor optionally retrieves the documents and converts them into plain text
|
||||||
OUTPUT_TYPES = [PreprocessingDataType.raw_text_document]
|
output_types = [PreprocessingDataType.raw_text_document]
|
||||||
|
|
||||||
URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)")
|
URL_VALIDATION_PATTERN = re.compile("^(https?://|file://|data:)")
|
||||||
|
|
||||||
|
@ -54,9 +55,9 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
async def preprocess(
|
async def preprocess(
|
||||||
self,
|
self,
|
||||||
preprocessor_id: str,
|
preprocessor_id: str,
|
||||||
preprocessor_inputs: List[PreprocessingInput],
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
options: Optional[PreprocessorOptions] = None,
|
options: Optional[PreprocessorOptions] = None,
|
||||||
) -> PreprocessingResponse:
|
) -> PreprocessorResponse:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for inp in preprocessor_inputs:
|
for inp in preprocessor_inputs:
|
||||||
|
@ -87,10 +88,18 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
|
|
||||||
results.append(document)
|
results.append(document)
|
||||||
|
|
||||||
return PreprocessingResponse(status=True, results=results)
|
return PreprocessorResponse(status=True, results=results)
|
||||||
|
|
||||||
|
async def chain_preprocess(
|
||||||
|
self,
|
||||||
|
preprocessors: PreprocessorChain,
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: Optional[bool] = False,
|
||||||
|
) -> PreprocessorResponse:
|
||||||
|
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _resolve_input_type(preprocessor_input: PreprocessingInput) -> PreprocessingDataType:
|
async def _resolve_input_type(preprocessor_input: PreprocessorInput) -> PreprocessingDataType:
|
||||||
if preprocessor_input.preprocessor_input_type is not None:
|
if preprocessor_input.preprocessor_input_type is not None:
|
||||||
return preprocessor_input.preprocessor_input_type
|
return preprocessor_input.preprocessor_input_type
|
||||||
|
|
||||||
|
@ -104,7 +113,7 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
return PreprocessingDataType.raw_text_document
|
return PreprocessingDataType.raw_text_document
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _fetch_document(preprocessor_input: PreprocessingInput) -> str | None:
|
async def _fetch_document(preprocessor_input: PreprocessorInput) -> str | None:
|
||||||
if isinstance(preprocessor_input.path_or_content, str):
|
if isinstance(preprocessor_input.path_or_content, str):
|
||||||
url = preprocessor_input.path_or_content
|
url = preprocessor_input.path_or_content
|
||||||
if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
|
if not InclineBasicPreprocessorImpl.URL_VALIDATION_PATTERN.match(url):
|
||||||
|
@ -125,7 +134,3 @@ class InclineBasicPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
r = await client.get(url)
|
r = await client.get(url)
|
||||||
|
|
||||||
return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text
|
return r.content if preprocessor_input.preprocessor_input_format == PreprocessingDataFormat.pdf else r.text
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_pdf(preprocessor_input: PreprocessingInput):
|
|
||||||
return
|
|
||||||
|
|
|
@ -13,10 +13,11 @@ from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.preprocessing import (
|
from llama_stack.apis.preprocessing import (
|
||||||
Preprocessing,
|
Preprocessing,
|
||||||
PreprocessingDataType,
|
PreprocessingDataType,
|
||||||
PreprocessingInput,
|
|
||||||
PreprocessingResponse,
|
|
||||||
Preprocessor,
|
Preprocessor,
|
||||||
|
PreprocessorChain,
|
||||||
|
PreprocessorInput,
|
||||||
PreprocessorOptions,
|
PreprocessorOptions,
|
||||||
|
PreprocessorResponse,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import Chunk
|
from llama_stack.apis.vector_io import Chunk
|
||||||
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
||||||
|
@ -27,10 +28,10 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
# this preprocessor receives URLs / paths to documents as input
|
# this preprocessor receives URLs / paths to documents as input
|
||||||
INPUT_TYPES = [PreprocessingDataType.document_uri]
|
input_types = [PreprocessingDataType.document_uri]
|
||||||
|
|
||||||
# this preprocessor either only converts the documents into a text format, or also chunks them
|
# this preprocessor either only converts the documents into a text format, or also chunks them
|
||||||
OUTPUT_TYPES = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks]
|
output_types = [PreprocessingDataType.raw_text_document, PreprocessingDataType.chunks]
|
||||||
|
|
||||||
def __init__(self, config: InlineDoclingConfig) -> None:
|
def __init__(self, config: InlineDoclingConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -50,9 +51,9 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
|
||||||
async def preprocess(
|
async def preprocess(
|
||||||
self,
|
self,
|
||||||
preprocessor_id: str,
|
preprocessor_id: str,
|
||||||
preprocessor_inputs: List[PreprocessingInput],
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
options: Optional[PreprocessorOptions] = None,
|
options: Optional[PreprocessorOptions] = None,
|
||||||
) -> PreprocessingResponse:
|
) -> PreprocessorResponse:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for inp in preprocessor_inputs:
|
for inp in preprocessor_inputs:
|
||||||
|
@ -74,4 +75,12 @@ class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate
|
||||||
result = converted_document.export_to_markdown()
|
result = converted_document.export_to_markdown()
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return PreprocessingResponse(status=True, results=results)
|
return PreprocessorResponse(status=True, results=results)
|
||||||
|
|
||||||
|
async def chain_preprocess(
|
||||||
|
self,
|
||||||
|
preprocessors: PreprocessorChain,
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: Optional[bool] = False,
|
||||||
|
) -> PreprocessorResponse:
|
||||||
|
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
|
||||||
|
|
|
@ -12,10 +12,11 @@ from llama_models.llama3.api import Tokenizer
|
||||||
from llama_stack.apis.preprocessing import (
|
from llama_stack.apis.preprocessing import (
|
||||||
Preprocessing,
|
Preprocessing,
|
||||||
PreprocessingDataType,
|
PreprocessingDataType,
|
||||||
PreprocessingInput,
|
|
||||||
PreprocessingResponse,
|
|
||||||
Preprocessor,
|
Preprocessor,
|
||||||
|
PreprocessorChain,
|
||||||
|
PreprocessorInput,
|
||||||
PreprocessorOptions,
|
PreprocessorOptions,
|
||||||
|
PreprocessorResponse,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import Chunk
|
from llama_stack.apis.vector_io import Chunk
|
||||||
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate
|
||||||
|
@ -31,8 +32,8 @@ class SimpleChunkingOptions(Enum):
|
||||||
|
|
||||||
class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
# this preprocessor receives plain text and returns chunks
|
# this preprocessor receives plain text and returns chunks
|
||||||
INPUT_TYPES = [PreprocessingDataType.raw_text_document]
|
input_types = [PreprocessingDataType.raw_text_document]
|
||||||
OUTPUT_TYPES = [PreprocessingDataType.chunks]
|
output_types = [PreprocessingDataType.chunks]
|
||||||
|
|
||||||
def __init__(self, config: InclineSimpleChunkingConfig) -> None:
|
def __init__(self, config: InclineSimpleChunkingConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -48,9 +49,9 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
async def preprocess(
|
async def preprocess(
|
||||||
self,
|
self,
|
||||||
preprocessor_id: str,
|
preprocessor_id: str,
|
||||||
preprocessor_inputs: List[PreprocessingInput],
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
options: Optional[PreprocessorOptions] = None,
|
options: Optional[PreprocessorOptions] = None,
|
||||||
) -> PreprocessingResponse:
|
) -> PreprocessorResponse:
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
window_len, overlap_len = self._resolve_chunk_size_params(options)
|
window_len, overlap_len = self._resolve_chunk_size_params(options)
|
||||||
|
@ -61,7 +62,15 @@ class InclineSimpleChunkingImpl(Preprocessing, PreprocessorsProtocolPrivate):
|
||||||
)
|
)
|
||||||
chunks.extend(new_chunks)
|
chunks.extend(new_chunks)
|
||||||
|
|
||||||
return PreprocessingResponse(status=True, results=chunks)
|
return PreprocessorResponse(status=True, results=chunks)
|
||||||
|
|
||||||
|
async def chain_preprocess(
|
||||||
|
self,
|
||||||
|
preprocessors: PreprocessorChain,
|
||||||
|
preprocessor_inputs: List[PreprocessorInput],
|
||||||
|
is_rag_chain: Optional[bool] = False,
|
||||||
|
) -> PreprocessorResponse:
|
||||||
|
return await self.preprocess(preprocessor_id="", preprocessor_inputs=preprocessor_inputs)
|
||||||
|
|
||||||
def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
|
def _resolve_chunk_size_params(self, options: PreprocessorOptions) -> Tuple[int, int]:
|
||||||
window_len = (options or {}).get(
|
window_len = (options or {}).get(
|
||||||
|
|
|
@ -18,7 +18,12 @@ from llama_stack.apis.common.content_types import (
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
|
from llama_stack.apis.preprocessing import (
|
||||||
|
Preprocessing,
|
||||||
|
PreprocessingDataFormat,
|
||||||
|
PreprocessorChainElement,
|
||||||
|
PreprocessorInput,
|
||||||
|
)
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
|
@ -67,17 +72,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
|
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
|
||||||
|
preprocessor_chain = [
|
||||||
conversion_response = await self.preprocessing_api.preprocess(
|
PreprocessorChainElement(preprocessor_id="builtin::basic"),
|
||||||
preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
|
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||||
|
]
|
||||||
|
preprocessor_response = await self.preprocessing_api.chain_preprocess(
|
||||||
|
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
|
||||||
)
|
)
|
||||||
converted_inputs = conversion_response.results
|
|
||||||
|
|
||||||
chunking_response = await self.preprocessing_api.preprocess(
|
chunks = preprocessor_response.results
|
||||||
preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
|
|
||||||
)
|
|
||||||
chunks = chunking_response.results
|
|
||||||
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return
|
return
|
||||||
|
@ -197,13 +201,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
|
def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput:
|
||||||
if document.mime_type == "application/pdf":
|
if document.mime_type == "application/pdf":
|
||||||
preprocessor_input_format = PreprocessingDataFormat.pdf
|
preprocessor_input_format = PreprocessingDataFormat.pdf
|
||||||
else:
|
else:
|
||||||
preprocessor_input_format = None
|
preprocessor_input_format = None
|
||||||
|
|
||||||
return PreprocessingInput(
|
return PreprocessorInput(
|
||||||
preprocessor_input_id=document.document_id,
|
preprocessor_input_id=document.document_id,
|
||||||
preprocessor_input_format=preprocessor_input_format,
|
preprocessor_input_format=preprocessor_input_format,
|
||||||
path_or_content=document.content,
|
path_or_content=document.content,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue