mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 16:12:16 +00:00
Added a draft implementation of the preprocessor chain.
This commit is contained in:
parent
16764a2f06
commit
b981181b25
7 changed files with 180 additions and 46 deletions
|
|
@ -18,7 +18,12 @@ from llama_stack.apis.common.content_types import (
|
|||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
|
||||
from llama_stack.apis.preprocessing import (
|
||||
Preprocessing,
|
||||
PreprocessingDataFormat,
|
||||
PreprocessorChainElement,
|
||||
PreprocessorInput,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -67,17 +72,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
|
||||
|
||||
conversion_response = await self.preprocessing_api.preprocess(
|
||||
preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
|
||||
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
|
||||
preprocessor_chain = [
|
||||
PreprocessorChainElement(preprocessor_id="builtin::basic"),
|
||||
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||
]
|
||||
preprocessor_response = await self.preprocessing_api.chain_preprocess(
|
||||
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
|
||||
)
|
||||
converted_inputs = conversion_response.results
|
||||
|
||||
chunking_response = await self.preprocessing_api.preprocess(
|
||||
preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
|
||||
)
|
||||
chunks = chunking_response.results
|
||||
chunks = preprocessor_response.results
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
|
@ -197,13 +201,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
|
||||
def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput:
|
||||
if document.mime_type == "application/pdf":
|
||||
preprocessor_input_format = PreprocessingDataFormat.pdf
|
||||
else:
|
||||
preprocessor_input_format = None
|
||||
|
||||
return PreprocessingInput(
|
||||
return PreprocessorInput(
|
||||
preprocessor_input_id=document.document_id,
|
||||
preprocessor_input_format=preprocessor_input_format,
|
||||
path_or_content=document.content,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue