mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-01 20:18:50 +00:00
feat: Add synthetic-data-kit for file_search doc conversion
This adds a `builtin::document_conversion` tool for converting documents when used with file_search that uses meta-llama/synthetic-data-kit. I also have another local implementation that uses Docling, but need to debug some segfault issues I'm hitting locally with that so pushing this first as a simpler reference implementation. Long-term I think we'll want a remote implemention here as well - like perhaps docling-serve or unstructured.io - but need to look more into that. This passes the existing `tests/verifications/openai_api/test_responses.py` but doesn't yet add any new tests for file types besides text and pdf. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
6fde601765
commit
e56690abef
18 changed files with 230 additions and 18 deletions
|
@ -24,7 +24,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
|||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `inline::synthetic-data-kit`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SyntheticDataKitToolRuntimeConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SyntheticDataKitToolRuntimeConfig, deps: dict[Api, Any]):
|
||||
from .synthetic_data_kit import SyntheticDataKitToolRuntimeImpl
|
||||
|
||||
impl = SyntheticDataKitToolRuntimeImpl(config, deps[Api.files])
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SyntheticDataKitToolRuntimeConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type
|
||||
|
||||
from .config import SyntheticDataKitToolRuntimeConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyntheticDataKitToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: SyntheticDataKitToolRuntimeConfig,
|
||||
files_api: Files,
|
||||
):
|
||||
self.config = config
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="convert_file_to_text",
|
||||
description="Convert a file to text",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="file_id",
|
||||
description="The id of the file to convert.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
if tool_name != "convert_file_to_text":
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
|
||||
file_id = kwargs["file_id"]
|
||||
file_response = await self.files_api.openai_retrieve_file(file_id)
|
||||
mime_type, _ = mimetypes.guess_type(file_response.filename)
|
||||
content_response = await self.files_api.openai_retrieve_file_content(file_id)
|
||||
|
||||
mime_category = mime_type.split("/")[0] if mime_type else None
|
||||
if mime_category == "text":
|
||||
# Don't use synthetic-data-kit if the file is already text
|
||||
content = content_from_data_and_mime_type(content_response.body, mime_type)
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
metadata={},
|
||||
)
|
||||
else:
|
||||
return await asyncio.to_thread(
|
||||
self.synthetic_data_kit_convert, content_response.body, file_response.filename
|
||||
)
|
||||
|
||||
def synthetic_data_kit_convert(self, content_body: bytes, filename: str) -> ToolInvocationResult:
|
||||
from synthetic_data_kit.core.ingest import process_file
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
file_path = os.path.join(tmpdir, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content_body)
|
||||
output_path = process_file(file_path, tmpdir)
|
||||
with open(output_path) as f:
|
||||
content = f.read()
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolInvocationResult(
|
||||
content="",
|
||||
error_message=f"Error converting file: {e}",
|
||||
error_code=1,
|
||||
metadata={},
|
||||
)
|
|
@ -16,6 +16,8 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
|||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = FaissVectorIOAdapter(
|
||||
config, deps[Api.inference], deps.get(Api.files, None), deps.get(Api.tool_runtime, None)
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -18,6 +18,7 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.apis.tools.tools import ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -150,10 +151,17 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: FaissVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None = None,
|
||||
tool_runtime_api: ToolRuntime | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
|
|
|
@ -15,6 +15,8 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
|||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = SQLiteVecVectorIOAdapter(
|
||||
config, deps[Api.inference], deps.get(Api.files, None), deps.get(Api.tool_runtime, None)
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -19,6 +19,7 @@ from numpy.typing import NDArray
|
|||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.apis.tools.tools import ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -434,10 +435,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
def __init__(
|
||||
self, config, inference_api: Inference, files_api: Files | None, tool_runtime_api: ToolRuntime | None
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
|
|
@ -34,6 +34,14 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::synthetic-data-kit",
|
||||
pip_packages=["synthetic-data-kit"],
|
||||
module="llama_stack.providers.inline.tool_runtime.synthetic-data-kit",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.synthetic-data-kit.config.SyntheticDataKitToolRuntimeConfig",
|
||||
api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.tool_runtime],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -33,7 +33,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.tool_runtime],
|
||||
),
|
||||
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
|
||||
# source distribution and the wheels are not available for all platforms.
|
||||
|
@ -44,7 +44,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.tool_runtime],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -54,7 +54,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.tool_runtime],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
|
|
|
@ -6,14 +6,14 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.files.files import OpenAIFileObject
|
||||
from llama_stack.apis.tools.tools import ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
|
@ -38,7 +38,7 @@ from llama_stack.apis.vector_io.vector_io import (
|
|||
VectorStoreFileStatus,
|
||||
VectorStoreListFilesResponse,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
|
||||
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,6 +56,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
# These should be provided by the implementing class
|
||||
openai_vector_stores: dict[str, dict[str, Any]]
|
||||
files_api: Files | None
|
||||
tool_runtime_api: ToolRuntime | None
|
||||
|
||||
@abstractmethod
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
|
@ -525,6 +526,14 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
)
|
||||
return vector_store_file_object
|
||||
|
||||
if not hasattr(self, "tool_runtime_api") or not self.tool_runtime_api:
|
||||
vector_store_file_object.status = "failed"
|
||||
vector_store_file_object.last_error = VectorStoreFileLastError(
|
||||
code="server_error",
|
||||
message="Tool runtime API is not available",
|
||||
)
|
||||
return vector_store_file_object
|
||||
|
||||
if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
|
||||
max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
|
||||
chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
|
||||
|
@ -534,12 +543,13 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
chunk_overlap_tokens = 400
|
||||
|
||||
try:
|
||||
file_response = await self.files_api.openai_retrieve_file(file_id)
|
||||
mime_type, _ = mimetypes.guess_type(file_response.filename)
|
||||
content_response = await self.files_api.openai_retrieve_file_content(file_id)
|
||||
|
||||
content = content_from_data_and_mime_type(content_response.body, mime_type)
|
||||
|
||||
tool_result = await self.tool_runtime_api.invoke_tool(
|
||||
"convert_file_to_text",
|
||||
{"file_id": file_id},
|
||||
)
|
||||
if tool_result.error_code or tool_result.error_message:
|
||||
raise ValueError(f"Failed to convert file to text: {tool_result.error_message}")
|
||||
content = cast(str, tool_result.content) # The tool always returns strings
|
||||
chunks = make_overlapped_chunks(
|
||||
file_id,
|
||||
content,
|
||||
|
|
|
@ -31,6 +31,7 @@ distribution_spec:
|
|||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
- inline::synthetic-data-kit
|
||||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
|
|
|
@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::rag-runtime",
|
||||
"inline::synthetic-data-kit",
|
||||
"remote::model-context-protocol",
|
||||
"remote::wolfram-alpha",
|
||||
],
|
||||
|
@ -91,6 +92,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::wolfram_alpha",
|
||||
provider_id="wolfram-alpha",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::document_conversion",
|
||||
provider_id="synthetic-data-kit",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
|
|
|
@ -114,6 +114,9 @@ providers:
|
|||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: synthetic-data-kit
|
||||
provider_type: inline::synthetic-data-kit
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
|
@ -158,5 +161,7 @@ tool_groups:
|
|||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
- toolgroup_id: builtin::document_conversion
|
||||
provider_id: synthetic-data-kit
|
||||
server:
|
||||
port: 8321
|
||||
|
|
|
@ -112,6 +112,9 @@ providers:
|
|||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: synthetic-data-kit
|
||||
provider_type: inline::synthetic-data-kit
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
|
@ -148,5 +151,7 @@ tool_groups:
|
|||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
- toolgroup_id: builtin::document_conversion
|
||||
provider_id: synthetic-data-kit
|
||||
server:
|
||||
port: 8321
|
||||
|
|
|
@ -38,6 +38,7 @@ distribution_spec:
|
|||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
- inline::synthetic-data-kit
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
|
|
|
@ -155,6 +155,9 @@ providers:
|
|||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: synthetic-data-kit
|
||||
provider_type: inline::synthetic-data-kit
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
|
@ -954,5 +957,7 @@ tool_groups:
|
|||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::document_conversion
|
||||
provider_id: synthetic-data-kit
|
||||
server:
|
||||
port: 8321
|
||||
|
|
|
@ -146,6 +146,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::rag-runtime",
|
||||
"inline::synthetic-data-kit",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
|
@ -192,6 +193,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::document_conversion",
|
||||
provider_id="synthetic-data-kit",
|
||||
),
|
||||
]
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue