diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 6df93052c..2f75567d8 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -20,6 +20,7 @@ class Api(Enum): eval = "eval" post_training = "post_training" tool_runtime = "tool_runtime" + preprocessing = "preprocessing" telemetry = "telemetry" @@ -30,6 +31,7 @@ class Api(Enum): scoring_functions = "scoring_functions" benchmarks = "benchmarks" tool_groups = "tool_groups" + preprocessors = "preprocessors" # built-in API inspect = "inspect" diff --git a/llama_stack/apis/preprocessing/__init__.py b/llama_stack/apis/preprocessing/__init__.py new file mode 100644 index 000000000..ebd5936ba --- /dev/null +++ b/llama_stack/apis/preprocessing/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .preprocessing import * # noqa: F401 F403 diff --git a/llama_stack/apis/preprocessing/preprocessing.py b/llama_stack/apis/preprocessing/preprocessing.py new file mode 100644 index 000000000..a19f018e0 --- /dev/null +++ b/llama_stack/apis/preprocessing/preprocessing.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from enum import Enum +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +from pydantic import BaseModel + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.preprocessing.preprocessors import Preprocessor +from llama_stack.schema_utils import json_schema_type, webmethod + + +class PreprocessingInputType(Enum): + document_content = "document_content" + document_path = "document_path" + + +@json_schema_type +class PreprocessingInput(BaseModel): + preprocessor_input_id: str + preprocessor_input_type: Optional[PreprocessingInputType] = None # defaults to None so the field can be omitted + path_or_content: str | URL + + +PreprocessorOptions = Dict[str, Any] + +# TODO: shouldn't be just a string +PreprocessingResult = str + + +@json_schema_type +class PreprocessingResponse(BaseModel): + status: bool + results: Optional[List[str | PreprocessingResult]] = None # defaults to None so the field can be omitted + + +class PreprocessorStore(Protocol): + def get_preprocessor(self, preprocessor_id: str) -> Preprocessor: ... + + +@runtime_checkable +class Preprocessing(Protocol): + preprocessor_store: PreprocessorStore + + @webmethod(route="/preprocess", method="POST") + async def preprocess( + self, + preprocessor_id: str, + preprocessor_inputs: List[PreprocessingInput], + options: PreprocessorOptions, + ) -> PreprocessingResponse: ... diff --git a/llama_stack/apis/preprocessing/preprocessors.py b/llama_stack/apis/preprocessing/preprocessors.py new file mode 100644 index 000000000..dbe513b70 --- /dev/null +++ b/llama_stack/apis/preprocessing/preprocessors.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable + +from pydantic import BaseModel + +from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod + + +@json_schema_type +class Preprocessor(Resource): + type: Literal[ResourceType.preprocessor.value] = ResourceType.preprocessor.value + + @property + def preprocessor_id(self) -> str: + return self.identifier + + @property + def provider_preprocessor_id(self) -> str: + return self.provider_resource_id + + metadata: Optional[Dict[str, Any]] = None + + +class PreprocessorInput(BaseModel): + preprocessor_id: str + provider_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +class ListPreprocessorsResponse(BaseModel): + data: List[Preprocessor] + + +@runtime_checkable +@trace_protocol +class Preprocessors(Protocol): + @webmethod(route="/preprocessors", method="GET") + async def list_preprocessors(self) -> ListPreprocessorsResponse: ... + + @webmethod(route="/preprocessors/{preprocessor_id:path}", method="GET") + async def get_preprocessor( + self, + preprocessor_id: str, + ) -> Optional[Preprocessor]: ... + + @webmethod(route="/preprocessors", method="POST") + async def register_preprocessor( + self, + preprocessor_id: str, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Preprocessor: ... + + @webmethod(route="/preprocessors/{preprocessor_id:path}", method="DELETE") + async def unregister_preprocessor( + self, + preprocessor_id: str, + ) -> None: ... 
diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 70ec63c55..e031092ef 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -18,6 +18,7 @@ class ResourceType(Enum): benchmark = "benchmark" tool = "tool" tool_group = "tool_group" + preprocessor = "preprocessor" class Resource(BaseModel): diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f62996081..6ce6b4895 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -14,6 +14,8 @@ from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model, ModelInput +from llama_stack.apis.preprocessing import Preprocessing, Preprocessor +from llama_stack.apis.preprocessing.preprocessors import PreprocessorInput from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput @@ -40,6 +42,7 @@ RoutableObject = Union[ Benchmark, Tool, ToolGroup, + Preprocessor, ] @@ -53,6 +56,7 @@ RoutableObjectWithProvider = Annotated[ Benchmark, Tool, ToolGroup, + Preprocessor, ], Field(discriminator="type"), ] @@ -65,6 +69,7 @@ RoutedProtocol = Union[ Scoring, Eval, ToolRuntime, + Preprocessing, ] @@ -175,6 +180,7 @@ a default SQLite store will be used.""", scoring_fns: List[ScoringFnInput] = Field(default_factory=list) benchmarks: List[BenchmarkInput] = Field(default_factory=list) tool_groups: List[ToolGroupInput] = Field(default_factory=list) + preprocessors: List[PreprocessorInput] = Field(default_factory=list) server: ServerConfig = Field( default_factory=ServerConfig, diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 384e2c3c8..8905863d5 100644 --- a/llama_stack/distribution/distribution.py +++ 
b/llama_stack/distribution/distribution.py @@ -51,6 +51,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.tool_groups, router_api=Api.tool_runtime, ), + AutoRoutedApiInfo( + routing_table_api=Api.preprocessors, + router_api=Api.preprocessing, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0bc2e774c..b8f9a9497 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -17,6 +17,8 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.preprocessing import Preprocessing +from llama_stack.apis.preprocessing.preprocessors import Preprocessors from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions @@ -41,6 +43,7 @@ from llama_stack.providers.datatypes import ( DatasetsProtocolPrivate, InlineProviderSpec, ModelsProtocolPrivate, + PreprocessorsProtocolPrivate, ProviderSpec, RemoteProviderConfig, RemoteProviderSpec, @@ -77,6 +80,8 @@ def api_protocol_map() -> Dict[Api, Any]: Api.post_training: PostTraining, Api.tool_groups: ToolGroups, Api.tool_runtime: ToolRuntime, + Api.preprocessing: Preprocessing, + Api.preprocessors: Preprocessors, } @@ -93,6 +98,7 @@ def additional_protocols_map() -> Dict[Api, Any]: Api.scoring_functions, ), Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks), + Api.preprocessing: (PreprocessorsProtocolPrivate, Preprocessors, Api.preprocessors), } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..ad2cc7fcf 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -14,6 +14,7 @@ from .routing_tables import ( 
BenchmarksRoutingTable, DatasetsRoutingTable, ModelsRoutingTable, + PreprocessorsRoutingTable, ScoringFunctionsRoutingTable, ShieldsRoutingTable, ToolGroupsRoutingTable, @@ -35,6 +36,7 @@ async def get_routing_table_impl( "scoring_functions": ScoringFunctionsRoutingTable, "benchmarks": BenchmarksRoutingTable, "tool_groups": ToolGroupsRoutingTable, + "preprocessors": PreprocessorsRoutingTable, } if api.value not in api_to_tables: @@ -50,6 +52,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> DatasetIORouter, EvalRouter, InferenceRouter, + PreprocessingRouter, SafetyRouter, ScoringRouter, ToolRuntimeRouter, @@ -64,6 +67,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "scoring": ScoringRouter, "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, + "preprocessing": PreprocessingRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index b0cb50e42..561a67e27 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -34,6 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import ModelType +from llama_stack.apis.preprocessing import Preprocessing, PreprocessingInput, PreprocessingResponse, PreprocessorOptions from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, @@ -482,3 +483,32 @@ class ToolRuntimeRouter(ToolRuntime): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + + +class PreprocessingRouter(Preprocessing): + """Forwards preprocess calls to the provider implementation the routing table returns for a preprocessor id.""" + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def 
initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def preprocess( + self, + preprocessor_id: str, + preprocessor_inputs: List[PreprocessingInput], + options: PreprocessorOptions, + ) -> PreprocessingResponse: + # preprocess() declares preprocessor_id as a required parameter, so it must be forwarded to the provider too. + return await self.routing_table.get_provider_impl(preprocessor_id).preprocess( + preprocessor_id=preprocessor_id, + preprocessor_inputs=preprocessor_inputs, + options=options, + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 80e9ecb7c..3d09d19ce 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -14,6 +14,7 @@ from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType +from llama_stack.apis.preprocessing.preprocessors import ListPreprocessorsResponse, Preprocessor, Preprocessors from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( ListScoringFunctionsResponse, @@ -66,6 +67,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable return await p.register_benchmark(obj) elif api == Api.tool_runtime: return await p.register_tool(obj) + elif api == Api.preprocessing: + return await p.register_preprocessor(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -80,6 +83,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: return await p.unregister_tool(obj.identifier) + elif api == Api.preprocessing: + return await p.unregister_preprocessor(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -127,6 +132,8 @@ class CommonRoutingTableImpl(RoutingTable): 
p.benchmark_store = self elif api == Api.tool_runtime: p.tool_store = self + elif api == Api.preprocessing: + p.preprocessor_store = self async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -148,6 +155,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("Eval", "benchmark") elif isinstance(self, ToolGroupsRoutingTable): return ("Tools", "tool") + elif isinstance(self, PreprocessorsRoutingTable): + return ("Preprocessing", "preprocessor") else: raise ValueError("Unknown routing table type") @@ -536,3 +545,42 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def shutdown(self) -> None: pass + + +class PreprocessorsRoutingTable(CommonRoutingTableImpl, Preprocessors): + """Routing table for Preprocessor resources: list, get, register and unregister.""" + + async def list_preprocessors(self) -> ListPreprocessorsResponse: + return ListPreprocessorsResponse(data=await self.get_all_with_type(ResourceType.preprocessor.value)) + + async def get_preprocessor(self, preprocessor_id: str) -> Optional[Preprocessor]: + return await self.get_object_by_identifier("preprocessor", preprocessor_id) + + async def register_preprocessor( + self, + preprocessor_id: str, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Preprocessor: + if provider_id is None: + # Only default the provider when exactly one is configured; otherwise the choice is ambiguous. + if len(self.impls_by_provider_id) == 1: + provider_id = next(iter(self.impls_by_provider_id)) + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." 
+ ) + preprocessor = Preprocessor( + identifier=preprocessor_id, + provider_resource_id=preprocessor_id, + provider_id=provider_id, + metadata=metadata, + ) + await self.register_object(preprocessor) + return preprocessor + + async def unregister_preprocessor(self, preprocessor_id: str) -> None: + existing_preprocessor = await self.get_preprocessor(preprocessor_id) + if existing_preprocessor is None: + raise ValueError(f"Preprocessor {preprocessor_id} not found") + await self.unregister_object(existing_preprocessor) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1328c88ef..47f732aa2 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -24,6 +24,8 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.preprocessing import Preprocessing +from llama_stack.apis.preprocessing.preprocessors import Preprocessors from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions @@ -65,6 +67,8 @@ class LlamaStack( ToolRuntime, RAGToolRuntime, Files, + Preprocessing, + Preprocessors, ): pass @@ -82,6 +86,7 @@ RESOURCES = [ ), ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"), ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), + ("preprocessors", Api.preprocessors, "register_preprocessor", "list_preprocessors"), ] diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 384582423..f34da79c0 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,6 +13,7 @@ from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasets import Dataset from llama_stack.apis.datatypes import Api 
from llama_stack.apis.models import Model +from llama_stack.apis.preprocessing import Preprocessor from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield from llama_stack.apis.tools import Tool @@ -58,6 +59,12 @@ class ToolsProtocolPrivate(Protocol): async def unregister_tool(self, tool_id: str) -> None: ... +class PreprocessorsProtocolPrivate(Protocol): + async def register_preprocessor(self, preprocessor: Preprocessor) -> None: ... + + async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... + + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/inline/preprocessing/__init__.py b/llama_stack/providers/inline/preprocessing/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/preprocessing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/preprocessing/docling/__init__.py b/llama_stack/providers/inline/preprocessing/docling/__init__.py new file mode 100644 index 000000000..15eeccb71 --- /dev/null +++ b/llama_stack/providers/inline/preprocessing/docling/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .config import InlineDoclingConfig + + +async def get_provider_impl( + config: InlineDoclingConfig, + _deps, +): + from .docling import InclineDoclingPreprocessorImpl + + impl = InclineDoclingPreprocessorImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/preprocessing/docling/config.py b/llama_stack/providers/inline/preprocessing/docling/config.py new file mode 100644 index 000000000..9527bd6fc --- /dev/null +++ b/llama_stack/providers/inline/preprocessing/docling/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from pydantic import BaseModel + + +class InlineDoclingConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/preprocessing/docling/docling.py b/llama_stack/providers/inline/preprocessing/docling/docling.py new file mode 100644 index 000000000..a3794f2f8 --- /dev/null +++ b/llama_stack/providers/inline/preprocessing/docling/docling.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import List + +from llama_stack.apis.preprocessing import ( + Preprocessing, + PreprocessingInput, + PreprocessingResponse, + Preprocessor, + PreprocessorOptions, +) +from llama_stack.providers.datatypes import PreprocessorsProtocolPrivate +from llama_stack.providers.inline.preprocessing.docling import InlineDoclingConfig + + +class InclineDoclingPreprocessorImpl(Preprocessing, PreprocessorsProtocolPrivate): # NOTE(review): "Incline" looks like a typo for "Inline" (compare InlineDoclingConfig) - confirm before renaming; all async methods below are `...` placeholders in this patch + def __init__(self, config: InlineDoclingConfig) -> None: + self.config = config + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... + + async def register_preprocessor(self, preprocessor: Preprocessor) -> None: ... 
+ + async def unregister_preprocessor(self, preprocessor_id: str) -> None: ... + + async def preprocess( + self, + preprocessor_id: str, + preprocessor_inputs: List[PreprocessingInput], + options: PreprocessorOptions, + ) -> PreprocessingResponse: ... diff --git a/llama_stack/providers/registry/preprocessing.py b/llama_stack/providers/registry/preprocessing.py new file mode 100644 index 000000000..9eaff312a --- /dev/null +++ b/llama_stack/providers/registry/preprocessing.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.providers.datatypes import ( + Api, + InlineProviderSpec, + ProviderSpec, +) + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.preprocessing, + provider_type="inline::docling", + pip_packages=["docling"], + module="llama_stack.providers.inline.preprocessing.docling", + config_class="llama_stack.providers.inline.preprocessing.docling.InlineDoclingConfig", + api_dependencies=[], + ), + ]