Fixed multiple bugs.

This commit is contained in:
ilya-kolchinsky 2025-03-06 16:46:59 +01:00
parent 6cbc298edb
commit f10a412898
7 changed files with 102 additions and 78 deletions

View file

@@ -20,11 +20,11 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.preprocessing import (
 Preprocessing,
+PreprocessingDataElement,
 PreprocessingDataFormat,
 PreprocessingDataType,
 PreprocessorChain,
 PreprocessorChainElement,
-PreprocessorInput,
 )
 from llama_stack.apis.tools import (
 RAGDocument,
@@ -81,10 +81,6 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
 preprocessor_chain: Optional[PreprocessorChain] = None,
 ) -> None:
 preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
-preprocessor_chain = [
-PreprocessorChainElement(preprocessor_id="builtin::basic"),
-PreprocessorChainElement(preprocessor_id="builtin::chunking"),
-]
 preprocessor_response = await self.preprocessing_api.chain_preprocess(
 preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN,
 preprocessor_inputs=preprocessor_inputs,
@@ -94,9 +90,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
 log.error("Preprocessor chain returned an error")
 return
-if preprocessor_response.preprocessor_output_type != PreprocessingDataType.chunks:
+if preprocessor_response.output_data_type != PreprocessingDataType.chunks:
 log.error(
-f"Preprocessor chain returned {preprocessor_response.preprocessor_output_type} instead of {PreprocessingDataType.chunks}"
+f"Preprocessor chain returned {preprocessor_response.output_data_type} instead of {PreprocessingDataType.chunks}"
 )
 return
@@ -105,8 +101,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
 log.error("No chunks returned by the preprocessor chain")
 return
+actual_chunks = [chunk.data_element_path_or_content for chunk in chunks]
 await self.vector_io_api.insert_chunks(
-chunks=chunks,
+chunks=actual_chunks,
 vector_db_id=vector_db_id,
 )
@@ -220,14 +217,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
 )
 @staticmethod
-def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessorInput:
+def _rag_document_to_preprocessor_input(document: RAGDocument) -> PreprocessingDataElement:
 if document.mime_type == "application/pdf":
-preprocessor_input_format = PreprocessingDataFormat.pdf
+data_element_format = PreprocessingDataFormat.pdf
 else:
-preprocessor_input_format = None
+data_element_format = None
-return PreprocessorInput(
-preprocessor_input_id=document.document_id,
-preprocessor_input_format=preprocessor_input_format,
-path_or_content=document.content,
+return PreprocessingDataElement(
+data_element_id=document.document_id,
+data_element_format=data_element_format,
+data_element_path_or_content=document.content,
 )