mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
Added the preprocessing chain parameter to the RAG tool insert API.
This commit is contained in:
parent
4c81a72214
commit
6cbc298edb
3 changed files with 13 additions and 2 deletions
|
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated, Protocol, runtime_checkable
|
from typing_extensions import Annotated, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||||
|
from llama_stack.apis.preprocessing import PreprocessorChain
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
@@ -79,6 +80,7 @@ class RAGToolRuntime(Protocol):
|
||||||
documents: List[RAGDocument],
|
documents: List[RAGDocument],
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
|
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Index documents so they can be used by the RAG system"""
|
"""Index documents so they can be used by the RAG system"""
|
||||||
...
|
...
|
||||||
|
|
|
@@ -458,9 +458,10 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
documents: List[RAGDocument],
|
documents: List[RAGDocument],
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
|
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||||
documents, vector_db_id, chunk_size_in_tokens
|
documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@@ -22,6 +22,7 @@ from llama_stack.apis.preprocessing import (
|
||||||
Preprocessing,
|
Preprocessing,
|
||||||
PreprocessingDataFormat,
|
PreprocessingDataFormat,
|
||||||
PreprocessingDataType,
|
PreprocessingDataType,
|
||||||
|
PreprocessorChain,
|
||||||
PreprocessorChainElement,
|
PreprocessorChainElement,
|
||||||
PreprocessorInput,
|
PreprocessorInput,
|
||||||
)
|
)
|
||||||
|
@@ -49,6 +50,11 @@ def make_random_string(length: int = 8):
|
||||||
|
|
||||||
|
|
||||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
|
DEFAULT_PREPROCESSING_CHAIN = [
|
||||||
|
PreprocessorChainElement(preprocessor_id="builtin::basic"),
|
||||||
|
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RagToolRuntimeConfig,
|
config: RagToolRuntimeConfig,
|
||||||
|
@@ -72,6 +78,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
documents: List[RAGDocument],
|
documents: List[RAGDocument],
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
|
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
|
preprocessor_inputs = [self._rag_document_to_preprocessor_input(d) for d in documents]
|
||||||
preprocessor_chain = [
|
preprocessor_chain = [
|
||||||
|
@@ -79,7 +86,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
PreprocessorChainElement(preprocessor_id="builtin::chunking"),
|
||||||
]
|
]
|
||||||
preprocessor_response = await self.preprocessing_api.chain_preprocess(
|
preprocessor_response = await self.preprocessing_api.chain_preprocess(
|
||||||
preprocessors=preprocessor_chain, preprocessor_inputs=preprocessor_inputs
|
preprocessors=preprocessor_chain or self.DEFAULT_PREPROCESSING_CHAIN,
|
||||||
|
preprocessor_inputs=preprocessor_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not preprocessor_response.success:
|
if not preprocessor_response.success:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue