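Added an optional `preprocessor_chain` parameter (default `None`) to the RAG tool `insert` API. When no chain is supplied, the memory tool runtime falls back to `DEFAULT_PREPROCESSING_CHAIN` (`builtin::basic` followed by `builtin::chunking`).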

This commit is contained in:
ilya-kolchinsky 2025-03-06 14:22:19 +01:00
parent 4c81a72214
commit 6cbc298edb
3 changed files with 13 additions and 2 deletions

View file

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.preprocessing import PreprocessorChain
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -79,6 +80,7 @@ class RAGToolRuntime(Protocol):
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
preprocessor_chain: Optional[PreprocessorChain] = None,
) -> None:
"""Index documents so they can be used by the RAG system"""
...

View file

@ -458,9 +458,10 @@ class ToolRuntimeRouter(ToolRuntime):
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
preprocessor_chain: Optional[PreprocessorChain] = None,
) -> None:
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
)
def __init__(

View file

@ -22,6 +22,7 @@ from llama_stack.apis.preprocessing import (
Preprocessing,
PreprocessingDataFormat,
PreprocessingDataType,
PreprocessorChain,
PreprocessorChainElement,
PreprocessorInput,
)
@ -49,6 +50,11 @@ def make_random_string(length: int = 8):
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
DEFAULT_PREPROCESSING_CHAIN = [
PreprocessorChainElement(preprocessor_id="builtin::basic"),
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
]
def __init__(
self,
config: RagToolRuntimeConfig,
@ -72,6 +78,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
preprocessor_chain: Optional[PreprocessorChain] = None,
) -> None:
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
preprocessor_chain = [
@ -79,7 +86,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
]
preprocessor_response = await self.preprocessing_api.chain_preprocess(
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN,
preprocessor_inputs=preprocessor_inputs,
)
if not preprocessor_response.success: