This commit is contained in:
Alina Ryan 2025-07-24 19:57:24 -04:00 committed by GitHub
commit fbf4a0141f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 442 additions and 0 deletions

View file

@ -103,6 +103,8 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups"
files = "files"
synthetic_data_generation = "synthetic_data_generation"
# built-in API
inspect = "inspect"

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# SPDX-License-Identifier: MIT
from typing import cast
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.providers.utils.resolver import get_provider_impl as _get_provider_impl
def get_provider_impl() -> SyntheticDataGeneration:
return cast(SyntheticDataGeneration, _get_provider_impl(SyntheticDataGeneration))

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# SPDX-License-Identifier: MIT
import requests
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
FilteringFunction,
SyntheticDataGeneration,
SyntheticDataGenerationResponse,
)
class SyntheticDataKitConfig(BaseModel):
"""Configuration for the Synthetic Data Kit provider"""
llm: Dict[str, Any] = Field(
default_factory=lambda: {
"provider": "vllm",
"model": "meta-llama/Llama-3.2-3B-Instruct",
}
)
vllm: Dict[str, Any] = Field(
default_factory=lambda: {
"api_base": "http://localhost:8000/v1",
}
)
generation: Dict[str, Any] = Field(
default_factory=lambda: {
"temperature": 0.7,
"chunk_size": 4000,
"num_pairs": 25,
}
)
curate: Dict[str, Any] = Field(
default_factory=lambda: {
"threshold": 7.0,
"batch_size": 8,
}
)
@classmethod
def sample_run_config(cls) -> "SyntheticDataKitConfig":
"""Create a sample configuration for testing"""
return cls()
class SyntheticDataKitProvider(SyntheticDataGeneration):
def __init__(self, config: SyntheticDataKitConfig):
self.config = config
self._validate_connection()
def _validate_connection(self) -> None:
"""Validate connection to vLLM server"""
try:
response = requests.get(f"http://localhost:{self.config.vllm['port']}/health")
response.raise_for_status()
except Exception as e:
raise RuntimeError(f"Failed to connect to vLLM server: {e}") from e
def synthetic_data_generate(
self,
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: str | None = None,
) -> SyntheticDataGenerationResponse:
# Convert dialogs to SDK format
formatted_dialogs = [{"role": dialog.role, "content": dialog.content} for dialog in dialogs]
payload = {
"dialogs": formatted_dialogs,
"filtering_function": filtering_function.value,
"model": model or self.config.llm["model"],
"generation": self.config.generation,
"curate": self.config.curate if filtering_function != FilteringFunction.none else None,
}
try:
response = requests.post(
f"http://localhost:{self.config.vllm['port']}/v1/synthetic-data-generation/generate",
json=payload,
)
response.raise_for_status()
result = response.json()
return SyntheticDataGenerationResponse(
synthetic_data=result.get("synthetic_data", []),
statistics=result.get("statistics"),
)
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Synthetic data generation failed: {e}") from e

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# SPDX-License-Identifier: MIT
from typing import Optional
import synthetic_data_kit as sdk
from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
FilteringFunction,
SyntheticDataGeneration,
SyntheticDataGenerationResponse,
)
from .config import SyntheticDataKitConfig
class SyntheticDataKitProvider(SyntheticDataGeneration):
def __init__(self, config: SyntheticDataKitConfig):
self.config = config
self.sdk = sdk.SyntheticDataKit(
llm=self.config.llm,
vllm=self.config.vllm,
generation=self.config.generation,
curate=self.config.curate,
)
async def synthetic_data_generate(
self,
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> SyntheticDataGenerationResponse:
# Convert dialogs to text format
text_content = "\n".join(d.content for d in dialogs)
# Generate synthetic data
if filtering_function == FilteringFunction.none:
result = await self.sdk.create(text_content, type="qa")
else:
# Generate and then curate
generated = await self.sdk.create(text_content, type="qa")
result = await self.sdk.curate(generated)
return SyntheticDataGenerationResponse(
synthetic_data=result.get("synthetic_data", []),
statistics=result.get("statistics"),
)

View file

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# SPDX-License-Identifier: MIT
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.synthetic_data_generation,
provider_type="inline::synthetic_data_kit",
pip_packages=[
"synthetic-data-kit",
"vllm",
"pydantic",
],
module="llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit_inline",
config_class="llama_stack.providers.inline.synthetic_data_generation.config.SyntheticDataKitConfig",
),
]