diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index 63a764725..17a975a93 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -35,6 +35,8 @@ class Api(Enum):
     tool_groups = "tool_groups"
     files = "files"
 
+    synthetic_data_generation = "synthetic_data_generation"
+
     # built-in API
     inspect = "inspect"
 
diff --git a/llama_stack/apis/synthetic_data_generation/providers/__init__.py b/llama_stack/apis/synthetic_data_generation/providers/__init__.py
new file mode 100644
index 000000000..2cb43c960
--- /dev/null
+++ b/llama_stack/apis/synthetic_data_generation/providers/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+from typing import cast
+
+from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
+from llama_stack.providers.utils.resolver import get_provider_impl as _get_provider_impl
+
+
+def get_provider_impl() -> SyntheticDataGeneration:
+    """Return the resolved provider for the synthetic data generation API.
+
+    Delegates to the shared resolver helper and narrows its result to the
+    ``SyntheticDataGeneration`` protocol for type checkers; ``cast`` performs
+    no runtime check.
+    """
+    return cast(SyntheticDataGeneration, _get_provider_impl(SyntheticDataGeneration))
diff --git a/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from .config import SyntheticDataKitConfig
+
+
+async def get_provider_impl(config: SyntheticDataKitConfig, _deps: dict[Any, Any]):
+    """Build the inline Synthetic Data Kit provider from its config.
+
+    NOTE(review): the ``(config, deps)`` coroutine signature follows the entry
+    point the stack resolver is expected to look up on provider modules;
+    ``_deps`` is unused here. The provider module is imported lazily so this
+    package can be imported without the ``synthetic_data_kit`` dependency.
+    """
+    from .synthetic_data_kit import SyntheticDataKitProvider
+
+    return SyntheticDataKitProvider(config)
diff --git a/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py
new file mode 100644
index 000000000..43108626c
--- /dev/null
+++ b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/config.py
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+import os
+from typing import Any, Dict, Optional
+
+import requests
+from pydantic import BaseModel, Field
+
+from llama_stack.apis.inference import Message
+from llama_stack.apis.synthetic_data_generation import (
+    FilteringFunction,
+    SyntheticDataGeneration,
+    SyntheticDataGenerationResponse,
+)
+
+
+class SyntheticDataKitConfig(BaseModel):
+    """Configuration for the Synthetic Data Kit provider"""
+
+    # LLM backend selection and default model name.
+    llm: Dict[str, Any] = Field(
+        default_factory=lambda: {
+            "provider": "vllm",
+            "model": "meta-llama/Llama-3.2-3B-Instruct",
+        }
+    )
+    # Connection settings for the vLLM server.
+    vllm: Dict[str, Any] = Field(
+        default_factory=lambda: {
+            "api_base": "http://localhost:8000/v1",
+        }
+    )
+    # Generation settings (temperature, chunk size, number of pairs).
+    generation: Dict[str, Any] = Field(
+        default_factory=lambda: {
+            "temperature": 0.7,
+            "chunk_size": 4000,
+            "num_pairs": 25,
+        }
+    )
+    # Curation settings (threshold, batch size).
+    curate: Dict[str, Any] = Field(
+        default_factory=lambda: {
+            "threshold": 7.0,
+            "batch_size": 8,
+        }
+    )
+
+    @classmethod
+    def sample_run_config(cls) -> "SyntheticDataKitConfig":
+        """Create a sample configuration for testing.
+
+        The default model is replaced by the value of the
+        ``SYNTHETIC_DATA_KIT_MODEL`` environment variable when it is set to a
+        non-empty string.
+        """
+        config = cls()
+        model_override = os.environ.get("SYNTHETIC_DATA_KIT_MODEL")
+        if model_override:
+            config.llm["model"] = model_override
+        return config
+
+
+class SyntheticDataKitProvider(SyntheticDataGeneration):
+    """Provider that forwards generation requests to a server over HTTP.
+
+    NOTE(review): the sibling ``synthetic_data_kit`` module defines a provider
+    class with the same name; confirm which of the two is meant to be used.
+    """
+
+    def __init__(self, config: SyntheticDataKitConfig):
+        self.config = config
+        # Fail fast: constructing the provider requires a reachable server.
+        self._validate_connection()
+
+    def _server_root(self) -> str:
+        """Return the server base URL without a trailing ``/v1`` segment.
+
+        An explicit ``port`` entry selects ``http://localhost:{port}``.
+        Otherwise the URL is derived from ``api_base``, the only connection
+        setting present in the default configuration.
+        """
+        vllm = self.config.vllm
+        if "port" in vllm:
+            return f"http://localhost:{vllm['port']}"
+        return str(vllm["api_base"]).rstrip("/").removesuffix("/v1")
+
+    def _validate_connection(self) -> None:
+        """Validate connection to vLLM server.
+
+        Raises:
+            RuntimeError: if the URL cannot be built or the health check fails.
+        """
+        try:
+            response = requests.get(
+                f"{self._server_root()}/health",
+                timeout=self.config.vllm.get("timeout"),
+            )
+            response.raise_for_status()
+        except Exception as e:
+            raise RuntimeError(f"Failed to connect to vLLM server: {e}") from e
+
+    def synthetic_data_generate(
+        self,
+        dialogs: list[Message],
+        filtering_function: FilteringFunction = FilteringFunction.none,
+        model: str | None = None,
+    ) -> SyntheticDataGenerationResponse:
+        """Post ``dialogs`` to the generation endpoint and wrap the reply.
+
+        Args:
+            dialogs: messages whose ``role`` and ``content`` are sent.
+            filtering_function: sent by value; curation settings are included
+                only when it is not ``FilteringFunction.none``.
+            model: overrides the configured model when given.
+
+        Raises:
+            RuntimeError: if the HTTP request fails.
+        """
+        # Convert dialogs to SDK format
+        formatted_dialogs = [{"role": dialog.role, "content": dialog.content} for dialog in dialogs]
+
+        payload = {
+            "dialogs": formatted_dialogs,
+            "filtering_function": filtering_function.value,
+            "model": model or self.config.llm["model"],
+            "generation": self.config.generation,
+            "curate": self.config.curate if filtering_function != FilteringFunction.none else None,
+        }
+
+        try:
+            response = requests.post(
+                f"{self._server_root()}/v1/synthetic-data-generation/generate",
+                json=payload,
+                timeout=self.config.vllm.get("timeout"),
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            return SyntheticDataGenerationResponse(
+                synthetic_data=result.get("synthetic_data", []),
+                statistics=result.get("statistics"),
+            )
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Synthetic data generation failed: {e}") from e
diff --git a/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/synthetic_data_kit.py b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/synthetic_data_kit.py
new file mode 100644
index 000000000..02ee4caba
--- /dev/null
+++ b/llama_stack/providers/inline/synthetic_data_generation/synthetic_data_kit/synthetic_data_kit.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+from typing import Optional
+import synthetic_data_kit as sdk
+
+from llama_stack.apis.inference import Message
+from llama_stack.apis.synthetic_data_generation import (
+    FilteringFunction,
+    SyntheticDataGeneration,
+    SyntheticDataGenerationResponse,
+)
+
+from .config import SyntheticDataKitConfig
+
+
+class SyntheticDataKitProvider(SyntheticDataGeneration):
+    """Inline provider backed by the ``synthetic_data_kit`` SDK."""
+
+    def __init__(self, config: SyntheticDataKitConfig):
+        self.config = config
+        self.sdk = sdk.SyntheticDataKit(
+            llm=self.config.llm,
+            vllm=self.config.vllm,
+            generation=self.config.generation,
+            curate=self.config.curate,
+        )
+
+    async def synthetic_data_generate(
+        self,
+        dialogs: list[Message],
+        filtering_function: FilteringFunction = FilteringFunction.none,
+        model: Optional[str] = None,
+    ) -> SyntheticDataGenerationResponse:
+        """Generate QA data from the newline-joined contents of ``dialogs``.
+
+        Curates the result unless filtering is ``none``; ``model`` is unused; empty input raises ``ValueError``.
+        """
+        if not dialogs:
+            raise ValueError("dialogs must contain at least one message")
+        result = await self.sdk.create("\n".join(d.content for d in dialogs), type="qa")
+        if filtering_function != FilteringFunction.none:
+            result = await self.sdk.curate(result)
+        return SyntheticDataGenerationResponse(
+            synthetic_data=result.get("synthetic_data", []),
+            statistics=result.get("statistics"),
+        )
diff --git a/llama_stack/providers/registry/synthetic_data_generation.py b/llama_stack/providers/registry/synthetic_data_generation.py
new file mode 100644
index 000000000..318187117
--- /dev/null
+++ b/llama_stack/providers/registry/synthetic_data_generation.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+
+
+def available_providers() -> list[ProviderSpec]:
+    """Return the provider specs for the synthetic data generation API."""
+    return [
+        InlineProviderSpec(
+            api=Api.synthetic_data_generation,
+            provider_type="inline::synthetic_data_kit",
+            pip_packages=["synthetic-data-kit", "vllm", "pydantic"],
+            module="llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit",
+            config_class=(
+                "llama_stack.providers.inline.synthetic_data_generation"
+                ".synthetic_data_kit.config.SyntheticDataKitConfig"
+            ),
+        ),
+    ]
diff --git a/tests/integration/providers/inline/synthetic_data_generation/test_synthetic_data_kit_integration.py b/tests/integration/providers/inline/synthetic_data_generation/test_synthetic_data_kit_integration.py
new file mode 100644
index 000000000..e6e166cae
--- /dev/null
+++ b/tests/integration/providers/inline/synthetic_data_generation/test_synthetic_data_kit_integration.py
@@ -0,0 +1,104 @@
+import os
+import pytest
+from typing import cast
+
+from llama_stack.apis.inference import Message
+from llama_stack.apis.synthetic_data_generation import (
+    SyntheticDataGeneration,
+    FilteringFunction,
+)
+from llama_stack.apis.synthetic_data_generation.providers import get_provider_impl
+from llama_stack.distribution.client import LlamaStackAsLibraryClient
+
+
+@pytest.fixture
+def client():
+    """Provide a library client for inline testing."""
+    return LlamaStackAsLibraryClient()
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_kit_provider_integration(client: LlamaStackAsLibraryClient):
+    """A single user message yields a non-empty list of question/answer dicts."""
+    # get_provider_impl() is synchronous, so its result must not be awaited.
+    provider = get_provider_impl()
+    assert isinstance(provider, SyntheticDataGeneration)
+
+    # Test single message generation
+    dialogs = [
+        Message(role="user", content="What is artificial intelligence?"),
+    ]
+
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.none,
+    )
+
+    assert response.synthetic_data is not None
+    assert len(response.synthetic_data) > 0
+    assert all(isinstance(item, dict) for item in response.synthetic_data)
+    assert all("question" in item and "answer" in item for item in response.synthetic_data)
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_kit_provider_with_filtering(client: LlamaStackAsLibraryClient):
+    """Filtered generation returns data together with curation statistics."""
+    provider = get_provider_impl()
+
+    # Test generation with filtering
+    dialogs = [
+        Message(role="user", content="Explain quantum computing."),
+        Message(role="assistant", content="Quantum computing uses quantum mechanics..."),
+    ]
+
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.top_k,
+    )
+
+    assert response.synthetic_data is not None
+    assert len(response.synthetic_data) > 0
+    assert response.statistics is not None
+    assert "threshold" in response.statistics
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_kit_provider_error_handling(client: LlamaStackAsLibraryClient):
+    """Expect errors for empty dialogs and for an unknown model."""
+    provider = get_provider_impl()
+
+    # Test with empty dialogs
+    with pytest.raises(ValueError):
+        await provider.synthetic_data_generate(
+            dialogs=[],
+            filtering_function=FilteringFunction.none,
+        )
+
+    # Test with invalid model
+    with pytest.raises(RuntimeError):
+        await provider.synthetic_data_generate(
+            dialogs=[Message(role="user", content="Test")],
+            filtering_function=FilteringFunction.none,
+            model="invalid-model",
+        )
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_kit_provider_with_env_config(client: LlamaStackAsLibraryClient, monkeypatch):
+    """Generation works when the model is selected through the environment."""
+    # monkeypatch restores the environment even when the test fails.
+    monkeypatch.setenv("SYNTHETIC_DATA_KIT_MODEL", "meta-llama/Llama-3.2-7B-Instruct")
+
+    provider = get_provider_impl()
+    dialogs = [
+        Message(role="user", content="What is deep learning?"),
+        Message(role="assistant", content="Deep learning is a subset of machine learning..."),
+    ]
+
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.none,
+    )
+
+    assert response.synthetic_data is not None
+    assert len(response.synthetic_data) > 0
diff --git a/tests/unit/providers/inline/synthetic_data_generation/test_synthetic_data_kit.py b/tests/unit/providers/inline/synthetic_data_generation/test_synthetic_data_kit.py
new file mode 100644
index 000000000..c4d6d42e9
--- /dev/null
+++ b/tests/unit/providers/inline/synthetic_data_generation/test_synthetic_data_kit.py
@@ -0,0 +1,138 @@
+import os
+import pytest
+from typing import cast
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from llama_stack.apis.inference import Message
+from llama_stack.apis.synthetic_data_generation import (
+    SyntheticDataGeneration,
+    SyntheticDataGenerationResponse,
+    FilteringFunction,
+)
+from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.config import (
+    SyntheticDataKitConfig,
+)
+from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.synthetic_data_kit import (
+    SyntheticDataKitProvider,
+)
+
+
+def test_config_defaults():
+    """Test default configuration values"""
+    config = SyntheticDataKitConfig()
+    assert config.llm["provider"] == "vllm"
+    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"
+    assert config.vllm["api_base"] == "http://localhost:8000/v1"
+    assert config.generation["temperature"] == 0.7
+    assert config.generation["chunk_size"] == 4000
+    assert config.curate["threshold"] == 7.0
+
+
+def test_sample_run_config(monkeypatch):
+    """Test sample configuration with environment variables"""
+    # Start from a clean environment so the default-model assertion is deterministic.
+    monkeypatch.delenv("SYNTHETIC_DATA_KIT_MODEL", raising=False)
+    config = SyntheticDataKitConfig.sample_run_config()
+    assert isinstance(config, SyntheticDataKitConfig)
+    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"
+
+    # Test environment variable override; monkeypatch undoes it after the test.
+    monkeypatch.setenv("SYNTHETIC_DATA_KIT_MODEL", "meta-llama/Llama-3.2-7B-Instruct")
+    config = SyntheticDataKitConfig.sample_run_config()
+    assert config.llm["model"] == "meta-llama/Llama-3.2-7B-Instruct"
+
+
+@pytest.fixture
+def mock_sdk():
+    """Create a mock SDK instance"""
+    with patch("synthetic_data_kit.SyntheticDataKit") as mock:
+        sdk_instance = MagicMock()
+        sdk_instance.create = AsyncMock()
+        sdk_instance.curate = AsyncMock()
+        mock.return_value = sdk_instance
+        yield sdk_instance
+
+
+@pytest.fixture
+def config():
+    """Provide a configuration with default values."""
+    return SyntheticDataKitConfig()
+
+
+@pytest.fixture
+def provider(config: SyntheticDataKitConfig, mock_sdk):
+    """Build the provider while the SDK class is patched by ``mock_sdk``."""
+    return SyntheticDataKitProvider(config)
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_generate_basic(provider: SyntheticDataGeneration, mock_sdk):
+    """Without filtering, the SDK ``create`` result is returned as the response."""
+    # Setup mock response
+    mock_sdk.create.return_value = {
+        "synthetic_data": [{"question": "What is ML?", "answer": "Machine learning..."}],
+        "statistics": {"count": 1}
+    }
+
+    dialogs = [Message(role="user", content="What is machine learning?")]
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.none,
+    )
+
+    # Verify SDK was called correctly
+    mock_sdk.create.assert_called_once_with("What is machine learning?", type="qa")
+    assert isinstance(response, SyntheticDataGenerationResponse)
+    assert len(response.synthetic_data) == 1
+    assert response.statistics == {"count": 1}
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_generate_with_filtering(provider: SyntheticDataGeneration, mock_sdk):
+    """With filtering, the generated data is passed on to ``curate``."""
+    # Setup mock responses
+    mock_sdk.create.return_value = {
+        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
+    }
+    mock_sdk.curate.return_value = {
+        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
+        "statistics": {"threshold": 7.5}
+    }
+
+    dialogs = [Message(role="user", content="Explain quantum computing.")]
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.top_k,
+    )
+
+    # Verify both create and curate were called, and curate received the generated data
+    mock_sdk.create.assert_called_once_with("Explain quantum computing.", type="qa")
+    mock_sdk.curate.assert_called_once_with(mock_sdk.create.return_value)
+    assert isinstance(response, SyntheticDataGenerationResponse)
+    assert response.statistics["threshold"] == 7.5
+
+
+@pytest.mark.asyncio
+async def test_synthetic_data_generate_multiple_messages(provider: SyntheticDataGeneration, mock_sdk):
+    """Contents of several messages are joined with newlines before generation."""
+    mock_sdk.create.return_value = {
+        "synthetic_data": [{"question": "What is deep learning?", "answer": "Deep..."}],
+        "statistics": {"count": 1}
+    }
+
+    dialogs = [
+        Message(role="user", content="What is deep learning?"),
+        Message(role="assistant", content="Deep learning is..."),
+        Message(role="user", content="Can you explain more?")
+    ]
+
+    response = await provider.synthetic_data_generate(
+        dialogs=dialogs,
+        filtering_function=FilteringFunction.none,
+    )
+
+    # Verify content was joined correctly
+    expected_content = "What is deep learning?\nDeep learning is...\nCan you explain more?"
+    mock_sdk.create.assert_called_once_with(expected_content, type="qa")
+    assert isinstance(response, SyntheticDataGenerationResponse)
+    assert response.synthetic_data is not None