feat: Add synthetic-data-kit for file_search doc conversion

This adds a `builtin::document_conversion` tool that uses
meta-llama/synthetic-data-kit to convert documents for use with
file_search. I also have a local implementation that uses Docling, but
I need to debug some segfaults I'm hitting with it locally, so I'm
pushing this first as a simpler reference implementation.
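
To make the flow concrete, here is a rough sketch of exercising the new
tool directly against a running distribution (a hypothetical sketch:
client method names follow llama-stack-client's OpenAI-compatible
surface and may differ slightly):

    # Hypothetical sketch -- assumes a local Llama Stack server with the
    # synthetic-data-kit tool runtime and a Files provider configured.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    # Upload a PDF through the Files API...
    with open("report.pdf", "rb") as f:
        uploaded = client.files.create(file=f, purpose="assistants")

    # ...then convert it to text via the new tool.
    result = client.tool_runtime.invoke_tool(
        tool_name="convert_file_to_text",
        kwargs={"file_id": uploaded.id},
    )
    print(result.content)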

Long-term I think we'll want a remote implementation here as well -
perhaps docling-serve or unstructured.io - but that needs more
investigation.

This passes the existing
`tests/verifications/openai_api/test_responses.py` suite but doesn't
yet add any new tests for file types besides text and PDF.
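
For reference, a follow-up test could look roughly like this
(hypothetical sketch against the OpenAI-compatible surface; the fixture
name and file path are placeholders):

    def test_attach_pdf_to_vector_store(openai_client):
        # Upload a PDF and attach it to a vector store; with this change
        # the file should be converted to text via
        # builtin::document_conversion before chunking and embedding.
        with open("tests/fixtures/sample.pdf", "rb") as f:
            uploaded = openai_client.files.create(file=f, purpose="assistants")

        vector_store = openai_client.vector_stores.create(name="docs")
        file_obj = openai_client.vector_stores.files.create_and_poll(
            vector_store_id=vector_store.id,
            file_id=uploaded.id,
        )
        assert file_obj.status == "completed"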

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning <bbrownin@redhat.com>
2025-06-20 18:09:14 -04:00
commit e56690abef (parent 6fde601765)
18 changed files with 230 additions and 18 deletions


@@ -24,7 +24,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `inline::synthetic-data-kit`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.datatypes import Api

from .config import SyntheticDataKitToolRuntimeConfig


async def get_provider_impl(config: SyntheticDataKitToolRuntimeConfig, deps: dict[Api, Any]):
    from .synthetic_data_kit import SyntheticDataKitToolRuntimeImpl

    impl = SyntheticDataKitToolRuntimeImpl(config, deps[Api.files])
    await impl.initialize()
    return impl


@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel


class SyntheticDataKitToolRuntimeConfig(BaseModel):
    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {}


@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import mimetypes
import os
import tempfile
from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.files.files import Files
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type

from .config import SyntheticDataKitToolRuntimeConfig

log = logging.getLogger(__name__)


class SyntheticDataKitToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
    def __init__(
        self,
        config: SyntheticDataKitToolRuntimeConfig,
        files_api: Files,
    ):
        self.config = config
        self.files_api = files_api

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="convert_file_to_text",
                    description="Convert a file to text",
                    parameters=[
                        ToolParameter(
                            name="file_id",
                            description="The id of the file to convert.",
                            parameter_type="string",
                        ),
                    ],
                ),
            ]
        )

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
        if tool_name != "convert_file_to_text":
            raise ValueError(f"Unknown tool: {tool_name}")

        file_id = kwargs["file_id"]
        file_response = await self.files_api.openai_retrieve_file(file_id)
        mime_type, _ = mimetypes.guess_type(file_response.filename)
        content_response = await self.files_api.openai_retrieve_file_content(file_id)

        mime_category = mime_type.split("/")[0] if mime_type else None
        if mime_category == "text":
            # Don't use synthetic-data-kit if the file is already text
            content = content_from_data_and_mime_type(content_response.body, mime_type)
            return ToolInvocationResult(
                content=content,
                metadata={},
            )
        else:
            # The conversion is synchronous and does blocking I/O, so run it
            # in a worker thread to avoid stalling the event loop
            return await asyncio.to_thread(
                self.synthetic_data_kit_convert, content_response.body, file_response.filename
            )

    def synthetic_data_kit_convert(self, content_body: bytes, filename: str) -> ToolInvocationResult:
        from synthetic_data_kit.core.ingest import process_file

        try:
            # Write the uploaded bytes to a temp file so synthetic-data-kit
            # can ingest it from disk
            with tempfile.TemporaryDirectory() as tmpdir:
                file_path = os.path.join(tmpdir, filename)
                with open(file_path, "wb") as f:
                    f.write(content_body)
                output_path = process_file(file_path, tmpdir)
                with open(output_path) as f:
                    content = f.read()
                return ToolInvocationResult(
                    content=content,
                    metadata={},
                )
        except Exception as e:
            return ToolInvocationResult(
                content="",
                error_message=f"Error converting file: {e}",
                error_code=1,
                metadata={},
            )


@@ -16,6 +16,8 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
+    impl = FaissVectorIOAdapter(
+        config, deps[Api.inference], deps.get(Api.files, None), deps.get(Api.tool_runtime, None)
+    )
     await impl.initialize()
     return impl


@@ -18,6 +18,7 @@ from numpy.typing import NDArray
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.tools.tools import ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
@@ -150,10 +151,17 @@ class FaissIndex(EmbeddingIndex):
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self,
+        config: FaissVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None = None,
+        tool_runtime_api: ToolRuntime | None = None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.files_api = files_api
+        self.tool_runtime_api = tool_runtime_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}


@@ -15,6 +15,8 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
+    impl = SQLiteVecVectorIOAdapter(
+        config, deps[Api.inference], deps.get(Api.files, None), deps.get(Api.tool_runtime, None)
+    )
     await impl.initialize()
     return impl


@@ -19,6 +19,7 @@ from numpy.typing import NDArray
 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.tools.tools import ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
@@ -434,10 +435,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
+    def __init__(
+        self, config, inference_api: Inference, files_api: Files | None, tool_runtime_api: ToolRuntime | None
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.files_api = files_api
+        self.tool_runtime_api = tool_runtime_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}


@@ -34,6 +34,14 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
             api_dependencies=[Api.vector_io, Api.inference],
         ),
+        InlineProviderSpec(
+            api=Api.tool_runtime,
+            provider_type="inline::synthetic-data-kit",
+            pip_packages=["synthetic-data-kit"],
+            module="llama_stack.providers.inline.tool_runtime.synthetic-data-kit",
+            config_class="llama_stack.providers.inline.tool_runtime.synthetic-data-kit.config.SyntheticDataKitToolRuntimeConfig",
+            api_dependencies=[Api.files],
+        ),
         remote_provider_spec(
             api=Api.tool_runtime,
             adapter=AdapterSpec(


@@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
             api_dependencies=[Api.inference],
-            optional_api_dependencies=[Api.files],
+            optional_api_dependencies=[Api.files, Api.tool_runtime],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
@@ -33,7 +33,7 @@ def available_providers() -> list[ProviderSpec]:
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],
-            optional_api_dependencies=[Api.files],
+            optional_api_dependencies=[Api.files, Api.tool_runtime],
         ),
         # NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
         # source distribution and the wheels are not available for all platforms.
@@ -44,7 +44,7 @@ def available_providers() -> list[ProviderSpec]:
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
-            optional_api_dependencies=[Api.files],
+            optional_api_dependencies=[Api.files, Api.tool_runtime],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
@@ -54,7 +54,7 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
             api_dependencies=[Api.inference],
-            optional_api_dependencies=[Api.files],
+            optional_api_dependencies=[Api.files, Api.tool_runtime],
         ),
         remote_provider_spec(
             Api.vector_io,


@@ -6,14 +6,14 @@
 import asyncio
 import logging
-import mimetypes
 import time
 import uuid
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, cast

 from llama_stack.apis.files import Files
 from llama_stack.apis.files.files import OpenAIFileObject
+from llama_stack.apis.tools.tools import ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     QueryChunksResponse,
@@ -38,7 +38,7 @@ from llama_stack.apis.vector_io.vector_io import (
     VectorStoreFileStatus,
     VectorStoreListFilesResponse,
 )
-from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks

 logger = logging.getLogger(__name__)
@@ -56,6 +56,7 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    tool_runtime_api: ToolRuntime | None

     @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
@@ -525,6 +526,14 @@
             )
             return vector_store_file_object

+        if not hasattr(self, "tool_runtime_api") or not self.tool_runtime_api:
+            vector_store_file_object.status = "failed"
+            vector_store_file_object.last_error = VectorStoreFileLastError(
+                code="server_error",
+                message="Tool runtime API is not available",
+            )
+            return vector_store_file_object
+
         if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
             max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
             chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
@@ -534,12 +543,13 @@
             chunk_overlap_tokens = 400

         try:
-            file_response = await self.files_api.openai_retrieve_file(file_id)
-            mime_type, _ = mimetypes.guess_type(file_response.filename)
-            content_response = await self.files_api.openai_retrieve_file_content(file_id)
-            content = content_from_data_and_mime_type(content_response.body, mime_type)
+            tool_result = await self.tool_runtime_api.invoke_tool(
+                "convert_file_to_text",
+                {"file_id": file_id},
+            )
+            if tool_result.error_code or tool_result.error_message:
+                raise ValueError(f"Failed to convert file to text: {tool_result.error_message}")
+            content = cast(str, tool_result.content)  # The tool always returns strings
             chunks = make_overlapped_chunks(
                 file_id,
                 content,

@@ -31,6 +31,7 @@ distribution_spec:
     - remote::brave-search
     - remote::tavily-search
     - inline::rag-runtime
+    - inline::synthetic-data-kit
     - remote::model-context-protocol
     - remote::wolfram-alpha
 image_type: conda


@@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::brave-search",
             "remote::tavily-search",
             "inline::rag-runtime",
+            "inline::synthetic-data-kit",
             "remote::model-context-protocol",
             "remote::wolfram-alpha",
         ],
@@ -91,6 +92,10 @@ def get_distribution_template() -> DistributionTemplate:
             toolgroup_id="builtin::wolfram_alpha",
             provider_id="wolfram-alpha",
         ),
+        ToolGroupInput(
+            toolgroup_id="builtin::document_conversion",
+            provider_id="synthetic-data-kit",
+        ),
     ]

     return DistributionTemplate(


@@ -114,6 +114,9 @@ providers:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
     config: {}
+  - provider_id: synthetic-data-kit
+    provider_type: inline::synthetic-data-kit
+    config: {}
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
@@ -158,5 +161,7 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::wolfram_alpha
   provider_id: wolfram-alpha
+- toolgroup_id: builtin::document_conversion
+  provider_id: synthetic-data-kit
 server:
   port: 8321


@@ -112,6 +112,9 @@ providers:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
     config: {}
+  - provider_id: synthetic-data-kit
+    provider_type: inline::synthetic-data-kit
+    config: {}
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
@@ -148,5 +151,7 @@ tool_groups:
   provider_id: rag-runtime
 - toolgroup_id: builtin::wolfram_alpha
   provider_id: wolfram-alpha
+- toolgroup_id: builtin::document_conversion
+  provider_id: synthetic-data-kit
 server:
   port: 8321


@@ -38,6 +38,7 @@ distribution_spec:
     - remote::brave-search
     - remote::tavily-search
     - inline::rag-runtime
+    - inline::synthetic-data-kit
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:


@@ -155,6 +155,9 @@ providers:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
     config: {}
+  - provider_id: synthetic-data-kit
+    provider_type: inline::synthetic-data-kit
+    config: {}
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
     config: {}
@@ -954,5 +957,7 @@ tool_groups:
   provider_id: tavily-search
 - toolgroup_id: builtin::rag
   provider_id: rag-runtime
+- toolgroup_id: builtin::document_conversion
+  provider_id: synthetic-data-kit
 server:
   port: 8321


@@ -146,6 +146,7 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::brave-search",
             "remote::tavily-search",
             "inline::rag-runtime",
+            "inline::synthetic-data-kit",
             "remote::model-context-protocol",
         ],
     }
@@ -192,6 +193,10 @@ def get_distribution_template() -> DistributionTemplate:
             toolgroup_id="builtin::rag",
             provider_id="rag-runtime",
         ),
+        ToolGroupInput(
+            toolgroup_id="builtin::document_conversion",
+            provider_id="synthetic-data-kit",
+        ),
     ]

     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",