mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
Added the preprocessing chain parameter to the RAG tool insert API.
This commit is contained in:
parent
4c81a72214
commit
6cbc298edb
3 changed files with 13 additions and 2 deletions
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.preprocessing import PreprocessorChain
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
@ -79,6 +80,7 @@ class RAGToolRuntime(Protocol):
|
|||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||
) -> None:
|
||||
"""Index documents so they can be used by the RAG system"""
|
||||
...
|
||||
|
|
|
@ -458,9 +458,10 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||
) -> None:
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.preprocessing import (
|
|||
Preprocessing,
|
||||
PreprocessingDataFormat,
|
||||
PreprocessingDataType,
|
||||
PreprocessorChain,
|
||||
PreprocessorChainElement,
|
||||
PreprocessorInput,
|
||||
)
|
||||
|
@ -49,6 +50,11 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
DEFAULT_PREPROCESSING_CHAIN = [
|
||||
PreprocessorChainElement(preprocessor_id="builtin::basic"),
|
||||
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RagToolRuntimeConfig,
|
||||
|
@ -72,6 +78,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
documents: List[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||
) -> None:
|
||||
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
|
||||
preprocessor_chain = [
|
||||
|
@ -79,7 +86,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||
]
|
||||
preprocessor_response = await self.preprocessing_api.chain_preprocess(
|
||||
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
|
||||
preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN,
|
||||
preprocessor_inputs=preprocessor_inputs,
|
||||
)
|
||||
|
||||
if not preprocessor_response.success:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue