Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
Add unit and integration tests for synthetic data kit provider
These tests follow Llama Stack's provider testing guidelines to validate:
- Configuration handling and environment variables work as expected
- Provider implementation behaves correctly in both unit and integration scenarios
- Error cases are properly handled
- Integration with Llama Stack's client SDK functions properly

Signed-off-by: Alina Ryan <aliryan@redhat.com>
parent f86f107f15
commit cc03093705
2 changed files with 233 additions and 0 deletions
@@ -0,0 +1,101 @@
import os
import pytest
from typing import cast

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    SyntheticDataGeneration,
    FilteringFunction,
)
from llama_stack.apis.synthetic_data_generation.providers import get_provider_impl
from llama_stack.distribution.client import LlamaStackAsLibraryClient


@pytest.fixture
async def client():
    # Use LlamaStackAsLibraryClient for inline testing
    return LlamaStackAsLibraryClient()


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_integration(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()
    assert isinstance(provider, SyntheticDataGeneration)

    # Test single message generation
    dialogs = [
        Message(role="user", content="What is artificial intelligence?"),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert all(isinstance(item, dict) for item in response.synthetic_data)
    assert all("question" in item and "answer" in item for item in response.synthetic_data)


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_filtering(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()

    # Test generation with filtering
    dialogs = [
        Message(role="user", content="Explain quantum computing."),
        Message(role="assistant", content="Quantum computing uses quantum mechanics..."),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    assert response.statistics is not None
    assert "threshold" in response.statistics


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_error_handling(client: LlamaStackAsLibraryClient):
    provider = await get_provider_impl()

    # Test with empty dialogs
    with pytest.raises(ValueError):
        await provider.synthetic_data_generate(
            dialogs=[],
            filtering_function=FilteringFunction.none,
        )

    # Test with invalid model
    with pytest.raises(RuntimeError):
        await provider.synthetic_data_generate(
            dialogs=[Message(role="user", content="Test")],
            filtering_function=FilteringFunction.none,
            model="invalid-model",
        )


@pytest.mark.asyncio
async def test_synthetic_data_kit_provider_with_env_config(client: LlamaStackAsLibraryClient):
    # Set environment variables for testing
    os.environ["SYNTHETIC_DATA_KIT_MODEL"] = "meta-llama/Llama-3.2-7B-Instruct"

    provider = await get_provider_impl()
    dialogs = [
        Message(role="user", content="What is deep learning?"),
        Message(role="assistant", content="Deep learning is a subset of machine learning..."),
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    assert response.synthetic_data is not None
    assert len(response.synthetic_data) > 0
    # Clean up environment
    del os.environ["SYNTHETIC_DATA_KIT_MODEL"]
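The env-config test above (and test_sample_run_config in the unit file below) pins down one behavior: SYNTHETIC_DATA_KIT_MODEL overrides the default model. A minimal sketch of that lookup, not part of the patch; the field layout mirrors the assertions in test_config_defaults below, but the class shape is an assumption, and the committed SyntheticDataKitConfig may well be a pydantic model instead.

import os

# Sketch only: not the committed SyntheticDataKitConfig.
class SketchSyntheticDataKitConfig:
    DEFAULT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"

    def __init__(self, llm=None):
        # Defaults match the values asserted in test_config_defaults below.
        self.llm = llm or {"provider": "vllm", "model": self.DEFAULT_MODEL}

    @classmethod
    def sample_run_config(cls):
        # SYNTHETIC_DATA_KIT_MODEL overrides the default model, as both
        # test files exercise.
        model = os.environ.get("SYNTHETIC_DATA_KIT_MODEL", cls.DEFAULT_MODEL)
        return cls(llm={"provider": "vllm", "model": model})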
@@ -0,0 +1,132 @@
import os
import pytest
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch

from llama_stack.apis.inference import Message
from llama_stack.apis.synthetic_data_generation import (
    SyntheticDataGeneration,
    SyntheticDataGenerationResponse,
    FilteringFunction,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.config import (
    SyntheticDataKitConfig,
)
from llama_stack.providers.inline.synthetic_data_generation.synthetic_data_kit.synthetic_data_kit import (
    SyntheticDataKitProvider,
)


def test_config_defaults():
    """Test default configuration values"""
    config = SyntheticDataKitConfig()
    assert config.llm["provider"] == "vllm"
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"
    assert config.vllm["api_base"] == "http://localhost:8000/v1"
    assert config.generation["temperature"] == 0.7
    assert config.generation["chunk_size"] == 4000
    assert config.curate["threshold"] == 7.0


def test_sample_run_config():
    """Test sample configuration with environment variables"""
    # Test default configuration
    config = SyntheticDataKitConfig.sample_run_config()
    assert isinstance(config, SyntheticDataKitConfig)
    assert config.llm["model"] == "meta-llama/Llama-3.2-3B-Instruct"

    # Test environment variable override
    os.environ["SYNTHETIC_DATA_KIT_MODEL"] = "meta-llama/Llama-3.2-7B-Instruct"
    config = SyntheticDataKitConfig.sample_run_config()
    assert config.llm["model"] == "meta-llama/Llama-3.2-7B-Instruct"
    # Clean up environment so the override does not leak into other tests
    del os.environ["SYNTHETIC_DATA_KIT_MODEL"]


@pytest.fixture
def mock_sdk():
    """Create a mock SDK instance"""
    # Patches the SyntheticDataKit name in the synthetic_data_kit module;
    # this assumes the provider resolves the class through that module at call time.
    with patch("synthetic_data_kit.SyntheticDataKit") as mock:
        sdk_instance = MagicMock()
        sdk_instance.create = AsyncMock()
        sdk_instance.curate = AsyncMock()
        mock.return_value = sdk_instance
        yield sdk_instance


@pytest.fixture
def config():
    return SyntheticDataKitConfig()


@pytest.fixture
def provider(config: SyntheticDataKitConfig, mock_sdk):
    return SyntheticDataKitProvider(config)


@pytest.mark.asyncio
async def test_synthetic_data_generate_basic(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock response
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is ML?", "answer": "Machine learning..."}],
        "statistics": {"count": 1}
    }

    dialogs = [Message(role="user", content="What is machine learning?")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify SDK was called correctly
    mock_sdk.create.assert_called_once_with("What is machine learning?", type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert len(response.synthetic_data) == 1
    assert response.statistics == {"count": 1}


@pytest.mark.asyncio
async def test_synthetic_data_generate_with_filtering(provider: SyntheticDataGeneration, mock_sdk):
    # Setup mock responses
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
    }
    mock_sdk.curate.return_value = {
        "synthetic_data": [{"question": "What is quantum?", "answer": "Quantum..."}],
        "statistics": {"threshold": 7.5}
    }

    dialogs = [Message(role="user", content="Explain quantum computing.")]
    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.top_k,
    )

    # Verify both create and curate were called
    mock_sdk.create.assert_called_once_with("Explain quantum computing.", type="qa")
    mock_sdk.curate.assert_called_once()
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.statistics["threshold"] == 7.5


@pytest.mark.asyncio
async def test_synthetic_data_generate_multiple_messages(provider: SyntheticDataGeneration, mock_sdk):
    mock_sdk.create.return_value = {
        "synthetic_data": [{"question": "What is deep learning?", "answer": "Deep..."}],
        "statistics": {"count": 1}
    }

    dialogs = [
        Message(role="user", content="What is deep learning?"),
        Message(role="assistant", content="Deep learning is..."),
        Message(role="user", content="Can you explain more?")
    ]

    response = await provider.synthetic_data_generate(
        dialogs=dialogs,
        filtering_function=FilteringFunction.none,
    )

    # Verify content was joined correctly
    expected_content = "What is deep learning?\nDeep learning is...\nCan you explain more?"
    mock_sdk.create.assert_called_once_with(expected_content, type="qa")
    assert isinstance(response, SyntheticDataGenerationResponse)
    assert response.synthetic_data is not None
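Taken together, the mocked expectations above pin down the provider's observable contract: message contents are joined with "\n", passed to sdk.create(content, type="qa"), curated only when filtering is enabled, and wrapped in a SyntheticDataGenerationResponse. A minimal sketch of an implementation satisfying those tests follows; the constructor and curate() shapes are assumptions, not the committed SyntheticDataKitProvider.

from synthetic_data_kit import SyntheticDataKit  # assumption: the package patched in mock_sdk above

from llama_stack.apis.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGenerationResponse,
)


class SketchProvider:
    """Sketch only -- not the committed SyntheticDataKitProvider."""

    def __init__(self, config):
        # Assumption: the SDK accepts the provider config directly.
        self.sdk = SyntheticDataKit(config)

    async def synthetic_data_generate(self, dialogs, filtering_function, model=None):
        if not dialogs:
            # test_synthetic_data_kit_provider_error_handling expects ValueError here
            raise ValueError("dialogs must not be empty")
        # test_synthetic_data_generate_multiple_messages expects "\n"-joined contents
        content = "\n".join(message.content for message in dialogs)
        result = await self.sdk.create(content, type="qa")
        if filtering_function != FilteringFunction.none:
            # test_synthetic_data_generate_with_filtering expects curate() to run;
            # assumption: curate takes the create() result
            result = await self.sdk.curate(result)
        return SyntheticDataGenerationResponse(
            synthetic_data=result["synthetic_data"],
            statistics=result.get("statistics"),
        )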