Added the preprocessing chain parameter to the RAG tool insert API.

ilya-kolchinsky 2025-03-06 14:22:19 +01:00
parent 4c81a72214
commit 6cbc298edb
3 changed files with 13 additions and 2 deletions

@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Protocol, runtime_checkable
 from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.preprocessing import PreprocessorChain
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -79,6 +80,7 @@ class RAGToolRuntime(Protocol):
         documents: List[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
+        preprocessor_chain: Optional[PreprocessorChain] = None,
     ) -> None:
         """Index documents so they can be used by the RAG system"""
         ...
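
For context, a hedged sketch of how a caller might exercise the extended signature. The `rag_tool` handle, the helper name, the document values, and the vector DB id are illustrative assumptions, and the method name `insert` is inferred from the router code in this commit rather than shown in the hunk above.

# Sketch only: assumes `rag_tool` implements the RAGToolRuntime protocol above.
from llama_stack.apis.preprocessing import PreprocessorChainElement
from llama_stack.apis.tools import RAGDocument


async def index_with_custom_chain(rag_tool) -> None:
    docs = [
        RAGDocument(
            document_id="doc-1",
            content="Example text to be preprocessed and indexed.",
            mime_type="text/plain",
            metadata={},
        )
    ]
    await rag_tool.insert(
        documents=docs,
        vector_db_id="my_vector_db",
        chunk_size_in_tokens=512,
        # Omit this argument (or pass None) to keep the pre-existing behavior.
        preprocessor_chain=[
            PreprocessorChainElement(preprocessor_id="builtin::basic"),
            PreprocessorChainElement(preprocessor_id="builtin::chunking"),
        ],
    )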

@@ -458,9 +458,10 @@ class ToolRuntimeRouter(ToolRuntime):
         documents: List[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
+        preprocessor_chain: Optional[PreprocessorChain] = None,
     ) -> None:
         return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-            documents, vector_db_id, chunk_size_in_tokens
+            documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
         )

     def __init__(

@@ -22,6 +22,7 @@ from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataFormat,
     PreprocessingDataType,
+    PreprocessorChain,
     PreprocessorChainElement,
     PreprocessorInput,
 )
@@ -49,6 +50,11 @@ def make_random_string(length: int = 8):
 class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+    DEFAULT_PREPROCESSING_CHAIN = [
+        PreprocessorChainElement(preprocessor_id="builtin::basic"),
+        PreprocessorChainElement(preprocessor_id="builtin::chunking"),
+    ]
+
     def __init__(
         self,
         config: RagToolRuntimeConfig,
@@ -72,6 +78,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         documents: List[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
+        preprocessor_chain: Optional[PreprocessorChain] = None,
     ) -> None:
         preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
         preprocessor_chain = [
@@ -79,7 +86,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             PreprocessorChainElement(preprocessor_id="builtin::chunking"),
         ]
         preprocessor_response = await self.preprocessing_api.chain_preprocess(
-            preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
+            preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN,
+            preprocessor_inputs=preprocessor_inputs,
         )
         if not preprocessor_response.success:
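
A minimal, self-contained sketch of the fallback semantics introduced above (`preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN`). The names below are local stand-ins, not the llama_stack types; only the `or`-fallback expression mirrors the diff.

# Stand-in types for illustration of the fallback behavior.
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ChainElement:
    preprocessor_id: str


DEFAULT_CHAIN = [ChainElement("builtin::basic"), ChainElement("builtin::chunking")]


def resolve_chain(preprocessor_chain: Optional[List[ChainElement]]) -> List[ChainElement]:
    # `or` falls back for both None and an empty list, since [] is falsy;
    # callers therefore cannot request a truly empty chain this way.
    return preprocessor_chain or DEFAULT_CHAIN


assert resolve_chain(None) == DEFAULT_CHAIN
assert resolve_chain([]) == DEFAULT_CHAIN
assert resolve_chain([ChainElement("builtin::basic")]) == [ChainElement("builtin::basic")]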