mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-05 05:35:22 +00:00
(feat) Add synthetic_data_kit provider integration for synthetic_data_generation API
The synthetic_data_kit provider integration enables high-quality synthetic dataset generation for fine-tuning LLMs. This commit sets up the initial provider registration and fixes provider resolution to properly handle type casting and imports, ensuring proper integration with llama-stack's provider system. Implementation of the actual provider functionality will follow in a subsequent commit. Signed-off-by: Alina Ryan <aliryan@redhat.com>
This commit is contained in:
parent
e867501073
commit
f86f107f15
7 changed files with 209 additions and 13 deletions
|
@ -35,6 +35,8 @@ class Api(Enum):
|
||||||
tool_groups = "tool_groups"
|
tool_groups = "tool_groups"
|
||||||
files = "files"
|
files = "files"
|
||||||
|
|
||||||
|
synthetic_data_generation = "synthetic_data_generation"
|
||||||
|
|
||||||
# built-in API
|
# built-in API
|
||||||
inspect = "inspect"
|
inspect = "inspect"
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||||
|
from llama_stack.providers.utils.resolver import get_provider_impl as _get_provider_impl
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_impl() -> SyntheticDataGeneration:
|
||||||
|
return cast(SyntheticDataGeneration, _get_provider_impl(SyntheticDataGeneration))
|
|
@ -1,13 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
|
||||||
|
|
||||||
SYNTHETIC_DATA_GENERATION_PROVIDERS: dict[str, SyntheticDataGeneration] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_provider(name: str = "meta_synthetic_data_kit") -> SyntheticDataGeneration:
|
|
||||||
raise NotImplementedError(f"No provider registered yet for synthetic_data_generation (requested: {name})")
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.synthetic_data_generation import (
|
||||||
|
FilteringFunction,
|
||||||
|
SyntheticDataGeneration,
|
||||||
|
SyntheticDataGenerationResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDataKitConfig(BaseModel):
|
||||||
|
"""Configuration for the Synthetic Data Kit provider"""
|
||||||
|
llm: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"provider": "vllm",
|
||||||
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
vllm: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"api_base": "http://localhost:8000/v1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
generation: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"chunk_size": 4000,
|
||||||
|
"num_pairs": 25,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
curate: Dict[str, Any] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"threshold": 7.0,
|
||||||
|
"batch_size": 8,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls) -> "SyntheticDataKitConfig":
|
||||||
|
"""Create a sample configuration for testing"""
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDataKitProvider(SyntheticDataGeneration):
|
||||||
|
def __init__(self, config: SyntheticDataKitConfig):
|
||||||
|
self.config = config
|
||||||
|
self._validate_connection()
|
||||||
|
|
||||||
|
def _validate_connection(self) -> None:
|
||||||
|
"""Validate connection to vLLM server"""
|
||||||
|
try:
|
||||||
|
response = requests.get(f"http://localhost:{self.config.vllm['port']}/health")
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to connect to vLLM server: {e}") from e
|
||||||
|
|
||||||
|
def synthetic_data_generate(
|
||||||
|
self,
|
||||||
|
dialogs: list[Message],
|
||||||
|
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> SyntheticDataGenerationResponse:
|
||||||
|
# Convert dialogs to SDK format
|
||||||
|
formatted_dialogs = [{"role": dialog.role, "content": dialog.content} for dialog in dialogs]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"dialogs": formatted_dialogs,
|
||||||
|
"filtering_function": filtering_function.value,
|
||||||
|
"model": model or self.config.llm["model"],
|
||||||
|
"generation": self.config.generation,
|
||||||
|
"curate": self.config.curate if filtering_function != FilteringFunction.none else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"http://localhost:{self.config.vllm['port']}/v1/synthetic-data-generation/generate",
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
return SyntheticDataGenerationResponse(
|
||||||
|
synthetic_data=result.get("synthetic_data", []),
|
||||||
|
statistics=result.get("statistics"),
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
raise RuntimeError(f"Synthetic data generation failed: {e}") from e
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import synthetic_data_kit as sdk
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.synthetic_data_generation import (
|
||||||
|
FilteringFunction,
|
||||||
|
SyntheticDataGeneration,
|
||||||
|
SyntheticDataGenerationResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import SyntheticDataKitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDataKitProvider(SyntheticDataGeneration):
|
||||||
|
def __init__(self, config: SyntheticDataKitConfig):
|
||||||
|
self.config = config
|
||||||
|
self.sdk = sdk.SyntheticDataKit(
|
||||||
|
llm=self.config.llm,
|
||||||
|
vllm=self.config.vllm,
|
||||||
|
generation=self.config.generation,
|
||||||
|
curate=self.config.curate,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def synthetic_data_generate(
|
||||||
|
self,
|
||||||
|
dialogs: list[Message],
|
||||||
|
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
) -> SyntheticDataGenerationResponse:
|
||||||
|
# Convert dialogs to text format
|
||||||
|
text_content = "\n".join(d.content for d in dialogs)
|
||||||
|
|
||||||
|
# Generate synthetic data
|
||||||
|
if filtering_function == FilteringFunction.none:
|
||||||
|
result = await self.sdk.create(text_content, type="qa")
|
||||||
|
else:
|
||||||
|
# Generate and then curate
|
||||||
|
generated = await self.sdk.create(text_content, type="qa")
|
||||||
|
result = await self.sdk.curate(generated)
|
||||||
|
|
||||||
|
return SyntheticDataGenerationResponse(
|
||||||
|
synthetic_data=result.get("synthetic_data", []),
|
||||||
|
statistics=result.get("statistics"),
|
||||||
|
)
|
28
llama_stack/providers/registry/synthetic_data_generation.py
Normal file
28
llama_stack/providers/registry/synthetic_data_generation.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
def available_providers() -> list[ProviderSpec]:
|
||||||
|
return [
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.synthetic_data_generation,
|
||||||
|
provider_type="inline::synthetic_data_kit",
|
||||||
|
pip_packages=[
|
||||||
|
"synthetic-data-kit",
|
||||||
|
"vllm",
|
||||||
|
"pydantic",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit_inline",
|
||||||
|
config_class="llama_stack.providers.inline.synthetic_data_generation.config.SyntheticDataKitConfig",
|
||||||
|
),
|
||||||
|
]
|
Loading…
Add table
Add a link
Reference in a new issue