Add unit and integration tests for synthetic data kit provider

These tests follow Llama Stack's provider testing guidelines to validate:
- Configuration handling and environment variables work as expected
- Provider implementation behaves correctly in both unit and integration scenarios
- Error cases are properly handled
- Integration with Llama Stack's client SDK functions properly

Signed-off-by: Alina Ryan <aliryan@redhat.com>
Alina Ryan 2025-05-29 16:24:24 -04:00
parent f86f107f15
commit cc03093705
2 changed files with 233 additions and 0 deletions

@@ -0,0 +1,101 @@
import os

import pytest

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    SyntheticDataGeneration,
    FilteringFunction,
)
from llama_stack.apis.synthetic_data_generation.providers import get_provider_impl
from llama_stack.distribution.client import LlamaStackAsLibraryClient


@pytest.fixture
def client():
    # Use LlamaStackAsLibraryClient for inline testing. Constructing the client
    # needs no event loop, so a plain synchronous fixture is enough and avoids
    # requiring pytest_asyncio.fixture for an async fixture.
    return LlamaStackAsLibraryClient()


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_integration(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()
    assert isinstance(provider, SyntheticDataGeneration)

    # Test single message generation
    dialogs = [
        Message(role="user", content="What is artificial intelligence?"),
    ]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert all(isinstance(item, dict) for item in response.synthetic_data)
    assert all("question" in item and "answer" in item for item in response.synthetic_data)


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_filtering(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()

    # Test generation with filtering
    dialogs = [
        Message(role="user", content="Explain quantum computing."),
        Message(role="assistant", content="Quantum computing uses quantum mechanics..."),
    ]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert response.statistics is not None
    assert "threshold" in response.statistics


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_error_handling(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()

    # Test with empty dialogs
    with pytest.raises(ValueError):
        await provider.synthetic_data_generate(
            dialogs=[],
            filtering_function=FilteringFunction.none,
        )

    # Test with invalid model
    with pytest.raises(RuntimeError):
        await provider.synthetic_data_generate(
            dialogs=[Message(role="user", content="Test")],
            filtering_function=FilteringFunction.none,
            model="invalid-model",
        )


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_env_config(client: LlamaStackAsLibraryClient):
    # Set environment variables for testing, and always restore them so the
    # override cannot leak into other tests even if an assertion fails.
    os.environ["SYNTHETIC_DATA_KIT_MODEL"] = "meta-llama/Llama-3.2-7B-Instruct"
    try:
        provider = await get_provider_impl()

        dialogs = [
            Message(role="user", content="What is deep learning?"),
            Message(role="assistant", content="Deep learning is a subset of machine learning..."),
        ]
        response = await provider.synthetic_data_generate(
            dialogs=dialogs,
            filtering_function=FilteringFunction.none,
        )

        assert response.synthetic_data is not None
        assert len(response.synthetic_data) > 0
    finally:
        # Clean up environment
        del os.environ["SYNTHETIC_DATA_KIT_MODEL"]
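
For reference, the provider surface these integration tests exercise can also be driven directly. The following is a minimal sketch assembled from the same calls the tests make; calling get_provider_impl() with no arguments mirrors the tests and is an assumption, not a documented standalone entry point:

import asyncio

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import FilteringFunction
from llama_stack.apis.synthetic_data_generation.providers import get_provider_impl


async def main() -> None:
    # Ask the inline provider to turn one prompt into question/answer pairs.
    provider = await get_provider_impl()
    response = await provider.synthetic_data_generate(
        dialogs=[Message(role="user", content="What is artificial intelligence?")],
        filtering_function=FilteringFunction.none,
    )
    for pair in response.synthetic_data:
        print(pair["question"], "->", pair["answer"])


asyncio.run(main())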

@@ -0,0 +1,132 @@
import pytest

from unittest.mock import AsyncMock, MagicMock, patch

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
    FilteringFunction,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.config import (
    SyntheticDataKitConfig,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.synthetic_data_kit import (
    SyntheticDataKitProvider,
)


def test_config_defaults():
    """Test default configuration values"""
    config = SyntheticDataKitConfig()
    assert config.llm["provider"] == "vllm"
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"
    assert config.vllm["api_base"] == "http://localhost:8000/v1"
    assert config.generation["temperature"] == 0.7
    assert config.generation["chunk_size"] == 4000
    assert config.curate["threshold"] == 7.0


def test_sample_run_config(monkeypatch):
    """Test sample configuration with environment variables"""
    # Test default configuration
    config = SyntheticDataKitConfig.sample_run_config()
    assert isinstance(config, SyntheticDataKitConfig)
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"

    # Test environment variable override; monkeypatch restores the original
    # environment after the test, so the override cannot leak into other tests.
    monkeypatch.setenv("SYNTHETIC_DATA_KIT_MODEL", "meta-llama/Llama-3.2-7B-Instruct")
    config = SyntheticDataKitConfig.sample_run_config()
    assert config.llm["model"] == "meta-llama/Llama-3.2-7B-Instruct"


@pytest.fixture
def mock_sdk():
    """Create a mock SDK instance"""
    # Patch the synthetic_data_kit.SyntheticDataKit class so the provider under
    # test talks to a mock instead of a real vLLM-backed toolkit.
    with patch("synthetic_data_kit.SyntheticDataKit") as mock:
        sdk_instance = MagicMock()
        sdk_instance.create = AsyncMock()
        sdk_instance.curate = AsyncMock()
        mock.return_value = sdk_instance
        yield sdk_instance


@pytest.fixture
def config():
    return SyntheticDataKitConfig()


@pytest.fixture
def provider(config: SyntheticDataKitConfig, mock_sdk):
    return SyntheticDataKitProvider(config)


@pytest.mark.asyncio
async def test_synthetic_data_generate_basic(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock response
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is ML?", "answer": "Machine learning..."}],
        "statistics": {"count": 1},
    }

    dialogs = [Message(role="user", content="What is machine learning?")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify SDK was called correctly
    mock_sdk.create.assert_called_once_with("What is machine learning?", type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert len(response.synthetic_data) == 1
    assert response.statistics == {"count": 1}


@pytest.mark.asyncio
async def test_synthetic_data_generate_with_filtering(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock responses
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
    }
    mock_sdk.curate.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
        "statistics": {"threshold": 7.5},
    }

    dialogs = [Message(role="user", content="Explain quantum computing.")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    # Verify both create and curate were called
    mock_sdk.create.assert_called_once_with("Explain quantum computing.", type="qa")
    mock_sdk.curate.assert_called_once()
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.statistics["threshold"] == 7.5


@pytest.mark.asyncio
async def test_synthetic_data_generate_multiple_messages(provider: SyntheticDataGeneration, mock_sdk):
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is deep learning?", "answer": "Deep..."}],
        "statistics": {"count": 1},
    }

    dialogs = [
        Message(role="user", content="What is deep learning?"),
        Message(role="assistant", content="Deep learning is..."),
        Message(role="user", content="Can you explain more?"),
    ]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify content was joined correctly
    expected_content = "What is deep learning?\nDeep learning is...\nCan you explain more?"
    mock_sdk.create.assert_called_once_with(expected_content, type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.synthetic_data is not None
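
Taken together, these unit tests pin down the provider's expected behavior: join every message's content with newlines, hand the result to synthetic_data_kit's create(..., type="qa"), and run curate only when a filtering function is requested. Below is a minimal sketch consistent with those assertions; the real implementation added by this commit may differ, and everything beyond the calls asserted above is an assumption:

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGenerationResponse,
)


class SyntheticDataKitProviderSketch:
    """Illustrative only; mirrors the behavior the mocked tests assert."""

    def __init__(self, sdk):
        # sdk is a synthetic_data_kit.SyntheticDataKit instance (mocked in tests).
        self.sdk = sdk

    async def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
    ) -> SyntheticDataGenerationResponse:
        if not dialogs:
            # The integration suite expects ValueError on empty input.
            raise ValueError("dialogs must not be empty")

        # Join all message contents, as the multiple-messages test asserts.
        content = "\n".join(message.content for message in dialogs)
        result = await self.sdk.create(content, type="qa")

        if filtering_function != FilteringFunction.none:
            # Curation supplies the filtering statistics asserted on above.
            result = await self.sdk.curate(result)

        return SyntheticDataGenerationResponse(
            synthetic_data=result["synthetic_data"],
            statistics=result.get("statistics"),
        )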