Initial implementation of RAG operator using the preprocessing endpoint.

This commit is contained in:
ilya-kolchinsky 2025-03-05 13:43:26 +01:00
parent c2bd31eb5c
commit 16764a2f06
8 changed files with 74 additions and 37 deletions

View file

@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.preprocessing])
await impl.initialize()
return impl

View file

@ -18,6 +18,7 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -30,10 +31,6 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
)
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
@ -51,10 +48,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
preprocessing_api: Preprocessing,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
self.preprocessing_api = preprocessing_api
async def initialize(self):
pass
@ -68,17 +67,17 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
chunks = []
for doc in documents:
content = await content_from_doc(doc)
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
)
)
preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
conversion_response = await self.preprocessing_api.preprocess(
preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
)
converted_inputs = conversion_response.results
chunking_response = await self.preprocessing_api.preprocess(
preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
)
chunks = chunking_response.results
if not chunks:
return
@ -196,3 +195,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
content=result.content,
metadata=result.metadata,
)
@staticmethod
def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
if document.mime_type == "application/pdf":
preprocessor_input_format = PreprocessingDataFormat.pdf
else:
preprocessor_input_format = None
return PreprocessingInput(
preprocessor_input_id=document.document_id,
preprocessor_input_format=preprocessor_input_format,
path_or_content=document.content,
)