mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 05:00:00 +00:00
Initial implementation of RAG operator using the preprocessing endpoint.
This commit is contained in:
parent
c2bd31eb5c
commit
16764a2f06
8 changed files with 74 additions and 37 deletions
|
|
@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
|
|||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
    """Construct and initialize the ``MemoryToolRuntimeImpl`` provider.

    Args:
        config: Configuration object handed to the implementation unchanged.
        deps: Mapping of already-resolved API implementations. It must hold
            entries for ``Api.vector_io``, ``Api.inference`` and
            ``Api.preprocessing``; a missing entry raises ``KeyError``.

    Returns:
        The implementation instance, after ``initialize()`` has been awaited.
    """
    # NOTE(review): the import is function-local in the original; presumably
    # this defers loading the implementation module until a provider is
    # actually requested -- confirm before hoisting it to module level.
    from .memory import MemoryToolRuntimeImpl

    # The constructor takes four arguments (config, vector_io, inference,
    # preprocessing). The earlier three-argument call was immediately
    # overwritten and has been dropped.
    impl = MemoryToolRuntimeImpl(
        config,
        deps[Api.vector_io],
        deps[Api.inference],
        deps[Api.preprocessing],
    )
    await impl.initialize()
    return impl
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.common.content_types import (
|
|||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.preprocessing import Preprocessing, PreprocessingDataFormat, PreprocessingInput
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -30,10 +31,6 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
|
@ -51,10 +48,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
preprocessing_api: Preprocessing,
|
||||
):
|
||||
self.config = config
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
self.preprocessing_api = preprocessing_api
|
||||
|
||||
async def initialize(self):
    """Start-up hook; this implementation has nothing to set up."""
    return None
|
||||
|
|
@ -68,17 +67,17 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
preprocessing_inputs = [self._rag_document_to_preprocessing_input(d) for d in documents]
|
||||
|
||||
conversion_response = await self.preprocessing_api.preprocess(
|
||||
preprocessor_id="builtin::basic", preprocessor_inputs=preprocessing_inputs
|
||||
)
|
||||
converted_inputs = conversion_response.results
|
||||
|
||||
chunking_response = await self.preprocessing_api.preprocess(
|
||||
preprocessor_id="builtin::chunking", preprocessor_inputs=converted_inputs
|
||||
)
|
||||
chunks = chunking_response.results
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
|
@ -196,3 +195,16 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
content=result.content,
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
||||
@staticmethod
def _rag_document_to_preprocessing_input(document: RAGDocument) -> PreprocessingInput:
    """Wrap a RAG document in a ``PreprocessingInput``.

    The document's id and content are passed through unchanged. Only the
    ``application/pdf`` MIME type is mapped to an explicit data format
    (``PreprocessingDataFormat.pdf``); for any other MIME type the format
    is left as ``None``.

    Args:
        document: The RAG document to convert.

    Returns:
        A ``PreprocessingInput`` carrying the document's id, content and
        the selected input format.
    """
    is_pdf = document.mime_type == "application/pdf"
    return PreprocessingInput(
        preprocessor_input_id=document.document_id,
        # NOTE(review): a None format presumably leaves format detection to
        # the preprocessor -- confirm against the preprocessing API.
        preprocessor_input_format=PreprocessingDataFormat.pdf if is_pdf else None,
        path_or_content=document.content,
    )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue