Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00

Merge cc03093705 into 40fdce79b3
Commit 7350bccc9d
8 changed files with 442 additions and 0 deletions
@@ -35,6 +35,8 @@ class Api(Enum):
    tool_groups = "tool_groups"
    files = "files"

    synthetic_data_generation = "synthetic_data_generation"

    # built-in API
    inspect = "inspect"

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import cast

from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.providers.utils.resolver import get_provider_impl as _get_provider_impl


def get_provider_impl() -> SyntheticDataGeneration:
    return cast(SyntheticDataGeneration, _get_provider_impl(SyntheticDataGeneration))

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict

import requests
from pydantic import BaseModel, Field

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
)


class SyntheticDataKitConfig(BaseModel):
    """Configuration for the Synthetic Data Kit provider"""

    llm: Dict[str, Any] = Field(
        default_factory=lambda: {
            "provider": "vllm",
            "model": "meta-llama/Llama-3.2-3B-Instruct",
        }
    )
    vllm: Dict[str, Any] = Field(
        default_factory=lambda: {
            "api_base": "http://localhost:8000/v1",
        }
    )
    generation: Dict[str, Any] = Field(
        default_factory=lambda: {
            "temperature": 0.7,
            "chunk_size": 4000,
            "num_pairs": 25,
        }
    )
    curate: Dict[str, Any] = Field(
        default_factory=lambda: {
            "threshold": 7.0,
            "batch_size": 8,
        }
    )

    @classmethod
    def sample_run_config(cls) -> "SyntheticDataKitConfig":
        """Create a sample configuration for testing"""
        config = cls()
        # Honor the SYNTHETIC_DATA_KIT_MODEL override that the unit tests in
        # this PR exercise; without this, the env-var test cannot pass.
        model = os.environ.get("SYNTHETIC_DATA_KIT_MODEL")
        if model:
            config.llm["model"] = model
        return config


class SyntheticDataKitProvider(SyntheticDataGeneration):
    def __init__(self, config: SyntheticDataKitConfig):
        self.config = config
        self._validate_connection()

    def _validate_connection(self) -> None:
        """Validate connection to the vLLM server"""
        # The config carries `api_base` rather than a `port` key, so derive
        # the health endpoint from it (the original read
        # self.config.vllm["port"], which would raise KeyError against the
        # defaults above).
        base_url = self.config.vllm["api_base"].removesuffix("/v1")
        try:
            response = requests.get(f"{base_url}/health")
            response.raise_for_status()
        except Exception as e:
            raise RuntimeError(f"Failed to connect to vLLM server: {e}") from e

    def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: str | None = None,
    ) -> SyntheticDataGenerationResponse:
        # Convert dialogs to the request format expected by the generation endpoint
        formatted_dialogs = [{"role": dialog.role, "content": dialog.content} for dialog in dialogs]

        payload = {
            "dialogs": formatted_dialogs,
            "filtering_function": filtering_function.value,
            "model": model or self.config.llm["model"],
            "generation": self.config.generation,
            "curate": self.config.curate if filtering_function != FilteringFunction.none else None,
        }

        try:
            response = requests.post(
                f"{self.config.vllm['api_base']}/synthetic-data-generation/generate",
                json=payload,
            )
            response.raise_for_status()
            result = response.json()

            return SyntheticDataGenerationResponse(
                synthetic_data=result.get("synthetic_data", []),
                statistics=result.get("statistics"),
            )
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Synthetic data generation failed: {e}") from e

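Note: a minimal usage sketch of the requests-based provider above, not part of the diff. It assumes a vLLM server is already reachable at the default api_base (http://localhost:8000/v1) and that the generation endpoint behaves as the provider expects; all names come from the file itself.

    # Sketch only: requires a running vLLM server at the default api_base.
    config = SyntheticDataKitConfig()
    provider = SyntheticDataKitProvider(config)  # probes the /health endpoint

    response = provider.synthetic_data_generate(
        dialogs=[Message(role="user", content="What is retrieval-augmented generation?")],
        filtering_function=FilteringFunction.none,
    )
    print(f"generated {len(response.synthetic_data)} QA pairs")
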
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

import synthetic_data_kit as sdk

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
)

from .config import SyntheticDataKitConfig


class SyntheticDataKitProvider(SyntheticDataGeneration):
    def __init__(self, config: SyntheticDataKitConfig):
        self.config = config
        self.sdk = sdk.SyntheticDataKit(
            llm=self.config.llm,
            vllm=self.config.vllm,
            generation=self.config.generation,
            curate=self.config.curate,
        )

    async def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: Optional[str] = None,
    ) -> SyntheticDataGenerationResponse:
        # Convert dialogs to text format
        text_content = "\n".join(d.content for d in dialogs)

        # Generate synthetic data
        if filtering_function == FilteringFunction.none:
            result = await self.sdk.create(text_content, type="qa")
        else:
            # Generate and then curate
            generated = await self.sdk.create(text_content, type="qa")
            result = await self.sdk.curate(generated)

        return SyntheticDataGenerationResponse(
            synthetic_data=result.get("synthetic_data", []),
            statistics=result.get("statistics"),
        )

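Note: a sketch (not part of the diff) of driving the SDK-backed provider above, including the generate-then-curate path. It assumes the synthetic-data-kit package is installed and exposes the async create/curate calls the provider uses; any filtering function other than none triggers curation.

    import asyncio

    # Sketch only: FilteringFunction.top_k exercises the curate branch.
    async def main() -> None:
        provider = SyntheticDataKitProvider(SyntheticDataKitConfig())
        response = await provider.synthetic_data_generate(
            dialogs=[Message(role="user", content="Explain vector databases.")],
            filtering_function=FilteringFunction.top_k,
        )
        print(response.statistics)

    asyncio.run(main())
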
llama_stack/providers/registry/synthetic_data_generation.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> list[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.synthetic_data_generation,
            provider_type="inline::synthetic_data_kit",
            pip_packages=[
                "synthetic-data-kit",
                "vllm",
                "pydantic",
            ],
            module="llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit_inline",
            config_class="llama_stack.providers.inline.synthetic_data_generation.config.SyntheticDataKitConfig",
        ),
    ]

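Note: a sketch (not part of the diff) of inspecting the registry entry, using only names defined in the file above. Whether the `module` and `config_class` paths resolve depends on the final file layout, which this diff does not fully show.

    from llama_stack.providers.registry.synthetic_data_generation import available_providers

    # Find the spec this file registers and print its declared metadata.
    spec = next(
        s for s in available_providers()
        if s.provider_type == "inline::synthetic_data_kit"
    )
    print(spec.api)           # Api.synthetic_data_generation
    print(spec.pip_packages)  # ['synthetic-data-kit', 'vllm', 'pydantic']
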
@@ -0,0 +1,101 @@
import os

import pytest

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
)
from llama_stack.apis.synthetic_data_generation.providers import get_provider_impl
from llama_stack.distribution.client import LlamaStackAsLibraryClient


@pytest.fixture
async def client():
    # Use LlamaStackAsLibraryClient for inline testing
    return LlamaStackAsLibraryClient()


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_integration(client: LlamaStackAsLibraryClient):
    # get_provider_impl() is synchronous (see the providers module above),
    # so it must not be awaited.
    provider = get_provider_impl()
    assert isinstance(provider, SyntheticDataGeneration)

    # Test single message generation
    dialogs = [
        Message(role="user", content="What is artificial intelligence?"),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert all(isinstance(item, dict) for item in response.synthetic_data)
    assert all("question" in item and "answer" in item for item in response.synthetic_data)


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_filtering(client: LlamaStackAsLibraryClient):
    provider = get_provider_impl()

    # Test generation with filtering
    dialogs = [
        Message(role="user", content="Explain quantum computing."),
        Message(role="assistant", content="Quantum computing uses quantum mechanics..."),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert response.statistics is not None
    assert "threshold" in response.statistics


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_error_handling(client: LlamaStackAsLibraryClient):
    provider = get_provider_impl()

    # Test with empty dialogs
    with pytest.raises(ValueError):
        await provider.synthetic_data_generate(
            dialogs=[],
            filtering_function=FilteringFunction.none,
        )

    # Test with invalid model
    with pytest.raises(RuntimeError):
        await provider.synthetic_data_generate(
            dialogs=[Message(role="user", content="Test")],
            filtering_function=FilteringFunction.none,
            model="invalid-model",
        )


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_env_config(client: LlamaStackAsLibraryClient):
    # Set environment variables for testing
    os.environ["SYNTHETIC_DATA_KIT_MODEL"] = "meta-llama/Llama-3.2-7B-Instruct"

    provider = get_provider_impl()
    dialogs = [
        Message(role="user", content="What is deep learning?"),
        Message(role="assistant", content="Deep learning is a subset of machine learning..."),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    # Clean up environment
    del os.environ["SYNTHETIC_DATA_KIT_MODEL"]

@@ -0,0 +1,132 @@
import os

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.config import (
    SyntheticDataKitConfig,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.synthetic_data_kit import (
    SyntheticDataKitProvider,
)


def test_config_defaults():
    """Test default configuration values"""
    config = SyntheticDataKitConfig()
    assert config.llm["provider"] == "vllm"
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"
    assert config.vllm["api_base"] == "http://localhost:8000/v1"
    assert config.generation["temperature"] == 0.7
    assert config.generation["chunk_size"] == 4000
    assert config.curate["threshold"] == 7.0


def test_sample_run_config():
    """Test sample configuration with environment variables"""
    # Test default configuration
    config = SyntheticDataKitConfig.sample_run_config()
    assert isinstance(config, SyntheticDataKitConfig)
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"

    # Test environment variable override
    os.environ["SYNTHETIC_DATA_KIT_MODEL"] = "meta-llama/Llama-3.2-7B-Instruct"
    try:
        config = SyntheticDataKitConfig.sample_run_config()
        assert config.llm["model"] == "meta-llama/Llama-3.2-7B-Instruct"
    finally:
        # Clean up so later tests see the default model again
        del os.environ["SYNTHETIC_DATA_KIT_MODEL"]


@pytest.fixture
def mock_sdk():
    """Create a mock SDK instance"""
    with patch("synthetic_data_kit.SyntheticDataKit") as mock:
        sdk_instance = MagicMock()
        sdk_instance.create = AsyncMock()
        sdk_instance.curate = AsyncMock()
        mock.return_value = sdk_instance
        yield sdk_instance


@pytest.fixture
def config():
    return SyntheticDataKitConfig()


@pytest.fixture
def provider(config: SyntheticDataKitConfig, mock_sdk):
    return SyntheticDataKitProvider(config)


@pytest.mark.asyncio
async def test_synthetic_data_generate_basic(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock response
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is ML?", "answer": "Machine learning..."}],
        "statistics": {"count": 1},
    }

    dialogs = [Message(role="user", content="What is machine learning?")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify SDK was called correctly
    mock_sdk.create.assert_called_once_with("What is machine learning?", type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert len(response.synthetic_data) == 1
    assert response.statistics == {"count": 1}


@pytest.mark.asyncio
async def test_synthetic_data_generate_with_filtering(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock responses
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
    }
    mock_sdk.curate.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
        "statistics": {"threshold": 7.5},
    }

    dialogs = [Message(role="user", content="Explain quantum computing.")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    # Verify both create and curate were called
    mock_sdk.create.assert_called_once_with("Explain quantum computing.", type="qa")
    mock_sdk.curate.assert_called_once()
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.statistics["threshold"] == 7.5


@pytest.mark.asyncio
async def test_synthetic_data_generate_multiple_messages(provider: SyntheticDataGeneration, mock_sdk):
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is deep learning?", "answer": "Deep..."}],
        "statistics": {"count": 1},
    }

    dialogs = [
        Message(role="user", content="What is deep learning?"),
        Message(role="assistant", content="Deep learning is..."),
        Message(role="user", content="Can you explain more?"),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify content was joined correctly
    expected_content = "What is deep learning?\nDeep learning is...\nCan you explain more?"
    mock_sdk.create.assert_called_once_with(expected_content, type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.synthetic_data is not None