# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# SPDX-License-Identifier: MIT

from typing import cast

from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.providers.utils.resolver import get_provider_impl as _get_provider_impl


def get_provider_impl() -> SyntheticDataGeneration:
    """Return the resolved provider, typed as ``SyntheticDataGeneration``.

    The lookup is delegated to the resolver helper, which is called with the
    ``SyntheticDataGeneration`` class itself. ``cast`` only informs static
    type checkers; it performs no runtime conversion or validation.
    """
    resolved = _get_provider_impl(SyntheticDataGeneration)
    return cast(SyntheticDataGeneration, resolved)
- -from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration - -SYNTHETIC_DATA_GENERATION_PROVIDERS: dict[str, SyntheticDataGeneration] = {} - - -def get_provider(name: str = "meta_synthetic_data_kit") -> SyntheticDataGeneration: - raise NotImplementedError(f"No provider registered yet for synthetic_data_generation (requested: {name})") diff --git a/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py new file mode 100644 index 000000000..43108626c --- /dev/null +++ b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# SPDX-License-Identifier: MIT

import requests
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
)


class SyntheticDataKitConfig(BaseModel):
    """Configuration for the Synthetic Data Kit provider.

    Every section is a free-form dictionary; the defaults below are used when
    a section is omitted.
    """

    # LLM backend selection; "model" is the fallback model name used when a
    # generate call does not pass one explicitly.
    llm: Dict[str, Any] = Field(
        default_factory=lambda: {
            "provider": "vllm",
            "model": "meta-llama/Llama-3.2-3B-Instruct",
        }
    )
    # Server location. The provider also honours an optional "port" key
    # (legacy form, always on localhost) and an optional "timeout" key
    # (seconds) for generation requests.
    vllm: Dict[str, Any] = Field(
        default_factory=lambda: {
            "api_base": "http://localhost:8000/v1",
        }
    )
    # Generation settings; forwarded verbatim in the request payload.
    generation: Dict[str, Any] = Field(
        default_factory=lambda: {
            "temperature": 0.7,
            "chunk_size": 4000,
            "num_pairs": 25,
        }
    )
    # Curation settings; forwarded only when a filtering function is selected.
    curate: Dict[str, Any] = Field(
        default_factory=lambda: {
            "threshold": 7.0,
            "batch_size": 8,
        }
    )

    @classmethod
    def sample_run_config(cls) -> "SyntheticDataKitConfig":
        """Return a configuration populated entirely with the field defaults."""
        return cls()


class SyntheticDataKitProvider(SyntheticDataGeneration):
    """Synthetic data generation provider that talks to a server over HTTP.

    The server is probed once at construction time, so creating an instance
    raises ``RuntimeError`` when the health endpoint cannot be reached.

    NOTE(review): the sibling module ``synthetic_data_kit.py`` in this change
    defines a class with the same name whose ``synthetic_data_generate`` is
    ``async``; this one is synchronous. Confirm which one is meant to be used.
    """

    # Fallback when the config carries neither "port" nor "api_base"; mirrors
    # the default declared on SyntheticDataKitConfig.vllm.
    _DEFAULT_API_BASE = "http://localhost:8000/v1"
    # Upper bound for the startup probe so an unreachable server fails fast
    # instead of blocking construction indefinitely.
    _HEALTH_TIMEOUT_SECONDS = 10

    def __init__(self, config: SyntheticDataKitConfig):
        """Store ``config`` and verify the server is reachable.

        Raises:
            RuntimeError: if the health probe fails.
        """
        self.config = config
        self._validate_connection()

    def _server_root(self) -> str:
        """Return the server base URL without a trailing ``/`` or ``/v1``.

        An explicit ``port`` entry keeps the original ``http://localhost:<port>``
        form. Otherwise the root is derived from ``api_base``, which is the
        only key present in the default configuration (indexing ``port``
        unconditionally raised ``KeyError`` there).
        """
        vllm_cfg = self.config.vllm
        if "port" in vllm_cfg:
            return f"http://localhost:{vllm_cfg['port']}"
        api_base = str(vllm_cfg.get("api_base", self._DEFAULT_API_BASE))
        return api_base.rstrip("/").removesuffix("/v1")

    def _validate_connection(self) -> None:
        """Validate connection to vLLM server.

        Raises:
            RuntimeError: if the ``/health`` request fails, times out, or
                returns an error status.
        """
        health_url = f"{self._server_root()}/health"
        try:
            response = requests.get(health_url, timeout=self._HEALTH_TIMEOUT_SECONDS)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Failed to connect to vLLM server: {e}") from e

    def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: str | None = None,
    ) -> SyntheticDataGenerationResponse:
        """Request synthetic data for ``dialogs`` from the configured server.

        Args:
            dialogs: Messages to send; only ``role`` and ``content`` are forwarded.
            filtering_function: Filter to apply. The ``curate`` settings are
                included in the payload only when this is not ``none``.
            model: Model name; falls back to ``config.llm["model"]`` when falsy.

        Returns:
            A response built from the ``synthetic_data`` (default ``[]``) and
            ``statistics`` (default ``None``) keys of the server's JSON reply.

        Raises:
            RuntimeError: if the HTTP request fails or returns an error status.
        """
        formatted_dialogs = [{"role": dialog.role, "content": dialog.content} for dialog in dialogs]

        payload = {
            "dialogs": formatted_dialogs,
            "filtering_function": filtering_function.value,
            "model": model or self.config.llm["model"],
            "generation": self.config.generation,
            "curate": self.config.curate if filtering_function != FilteringFunction.none else None,
        }

        url = f"{self._server_root()}/v1/synthetic-data-generation/generate"
        try:
            response = requests.post(
                url,
                json=payload,
                # Optional, in seconds. None (the default) keeps the original
                # behaviour of waiting for as long as generation takes.
                timeout=self.config.vllm.get("timeout"),
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Synthetic data generation failed: {e}") from e

        return SyntheticDataGenerationResponse(
            synthetic_data=result.get("synthetic_data", []),
            statistics=result.get("statistics"),
        )
#
# SPDX-License-Identifier: MIT

from typing import Optional
import synthetic_data_kit as sdk

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
)

from .config import SyntheticDataKitConfig


class SyntheticDataKitProvider(SyntheticDataGeneration):
    """Provider that delegates generation to a ``synthetic_data_kit`` client object."""

    def __init__(self, config: SyntheticDataKitConfig):
        """Keep ``config`` and build the client from its four sections."""
        self.config = config
        cfg = self.config
        self.sdk = sdk.SyntheticDataKit(
            llm=cfg.llm,
            vllm=cfg.vllm,
            generation=cfg.generation,
            curate=cfg.curate,
        )

    async def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: Optional[str] = None,
    ) -> SyntheticDataGenerationResponse:
        """Create QA-style synthetic data from ``dialogs``, optionally curated.

        The message contents are joined with newlines and passed to the
        client's ``create`` call. When ``filtering_function`` is anything other
        than ``none``, the created output is additionally passed through
        ``curate``. ``model`` is accepted but not used by this method.
        """
        # Flatten the conversation into one newline-separated text blob.
        contents = [message.content for message in dialogs]
        text_content = "\n".join(contents)

        # Creation happens exactly once in either mode; curation is the only
        # step that depends on the filtering function.
        result = await self.sdk.create(text_content, type="qa")
        if filtering_function != FilteringFunction.none:
            result = await self.sdk.curate(result)

        samples = result.get("synthetic_data", [])
        stats = result.get("statistics")
        return SyntheticDataGenerationResponse(
            synthetic_data=samples,
            statistics=stats,
        )
#
# SPDX-License-Identifier: MIT

from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> list[ProviderSpec]:
    """Return the provider specs available for the synthetic data generation API.

    A single inline provider, ``inline::synthetic_data_kit``, is declared.

    The ``module`` and ``config_class`` paths point at the
    ``synthetic_data_kit`` package under
    ``llama_stack/providers/inline/synthetic_data_generation/``, which is where
    this change adds ``__init__.py``, ``config.py`` and
    ``synthetic_data_kit.py``. The previous values
    (``...synthetic_data_kit_inline`` and
    ``...synthetic_data_generation.config``) named modules that this change
    does not create.

    NOTE(review): the package ``__init__.py`` added alongside contains only a
    license header; confirm it exposes whatever entry point the provider
    loader expects from ``module``.
    """
    return [
        InlineProviderSpec(
            api=Api.synthetic_data_generation,
            provider_type="inline::synthetic_data_kit",
            pip_packages=[
                "synthetic-data-kit",
                "vllm",
                "pydantic",
            ],
            module="llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit",
            config_class=(
                "llama_stack.providers.inline.synthetic_data_generation"
                ".synthetic_data_kit.config.SyntheticDataKitConfig"
            ),
        ),
    ]