Mock LlamaStackClient in unit tests

raspawar 2025-03-24 15:14:17 +05:30
parent 0d4dc06a3c
commit 34db80fb15
4 changed files with 294 additions and 44 deletions


@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
@pytest.fixture(scope="session", autouse=True)
def patch_aiohttp_session():
with patch("aiohttp.ClientSession", return_value=mock_session):
yield
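
For reference, a minimal sketch (an assumption, not part of this commit) of what the autouse fixture buys the suite: any code under test that opens an aiohttp session receives the shared mock instead of a real HTTP client. The test name is illustrative, and it assumes mock_session is in scope (or importable) in the test module.

import asyncio
import aiohttp

def test_client_session_is_mocked():
    async def check():
        # The patched constructor returns the shared MagicMock defined above.
        async with aiohttp.ClientSession() as session:
            assert session is mock_session
            assert session.closed is False
    asyncio.run(check())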


@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_client.types.algorithm_config_param import QatFinetuningConfig
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
from llama_stack_client.types.post_training_job import PostTrainingJob
class MockLlamaStackClient:
"""Mock client for testing NVIDIA post-training functionality."""
def __init__(self, provider="nvidia"):
self.provider = provider
self.post_training = MockPostTraining()
self.inference = MockInference()
self._session = None
def initialize(self):
"""Mock initialization method."""
return True
def close(self):
"""Close any open resources."""
pass
class MockPostTraining:
"""Mock post-training module."""
def __init__(self):
self.job = MockPostTrainingJob()
def supervised_fine_tune(
self,
job_uuid,
model,
checkpoint_dir,
algorithm_config,
training_config,
logger_config,
hyperparam_search_config,
):
"""Mock supervised fine-tuning method."""
if isinstance(algorithm_config, QatFinetuningConfig):
raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
# Return a mock PostTrainingJob
return PostTrainingJob(
job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
status="created",
created_at="2024-12-09T04:06:28.542884",
updated_at="2024-12-09T04:06:28.542884",
model=model,
dataset_id=training_config.data_config.dataset_id,
output_model="default/job-1234",
)
async def supervised_fine_tune_async(
self,
job_uuid,
model,
checkpoint_dir,
algorithm_config,
training_config,
logger_config,
hyperparam_search_config,
):
"""Mock async supervised fine-tuning method."""
if isinstance(algorithm_config, QatFinetuningConfig):
raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
# Return a mock response dictionary
return {
"job_uuid": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "created",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"model": model,
"dataset_id": training_config.data_config.dataset_id,
"output_model": "default/job-1234",
}
class MockPostTrainingJob:
"""Mock post-training job module."""
def status(self, job_uuid):
"""Mock job status method."""
return JobStatusResponse(
status="completed",
steps_completed=1210,
epochs_completed=2,
percentage_done=100.0,
best_epoch=2,
train_loss=1.718016266822815,
val_loss=1.8661999702453613,
)
def list(self):
"""Mock job list method."""
return [
PostTrainingJob(
job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
status="completed",
created_at="2024-12-09T04:06:28.542884",
updated_at="2024-12-09T04:21:19.852832",
model="meta-llama/Llama-3.1-8B-Instruct",
dataset_id="sample-basic-test",
output_model="default/job-1234",
)
]
def cancel(self, job_uuid):
"""Mock job cancel method."""
return None
class MockInference:
"""Mock inference module."""
async def completion(
self,
content,
stream=False,
model_id=None,
sampling_params=None,
):
"""Mock completion method."""
return {
"id": "cmpl-123456",
"object": "text_completion",
"created": 1677858242,
"model": model_id,
"choices": [
{
"text": "The next GTC will take place in the middle of March, 2023.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
}
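
A small smoke-test sketch (assumed, not part of the commit) showing how the mock client can be exercised directly, with no network access and no real provider configuration. The test name is illustrative.

def test_mock_client_smoke():
    # Purely local: nothing here reaches the NVIDIA customizer service.
    client = MockLlamaStackClient()
    assert client.provider == "nvidia"
    assert client.initialize() is True
    assert client.post_training.job.cancel("cust-JGTaMbJMdqjJU8WbQdN9Q2") is None
    client.close()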


@@ -7,16 +7,30 @@
import os
import unittest
import warnings
from unittest.mock import patch
from unittest.mock import patch, AsyncMock, MagicMock
import pytest
import atexit
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
TrainingConfigEfficiencyConfig,
)
from .mock_llama_stack_client import MockLlamaStackClient
# Create a mock session
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
patch("aiohttp.ClientSession", return_value=mock_session).start()
atexit.register(lambda: patch.stopall())
class TestNvidiaParameters(unittest.TestCase):
@@ -25,8 +39,10 @@ class TestNvidiaParameters(unittest.TestCase):
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
# Use the mock client
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
@@ -40,6 +56,10 @@ class TestNvidiaParameters(unittest.TestCase):
}
def tearDown(self):
# Close the client if it has a close method
if hasattr(self.llama_stack_client, "close"):
self.llama_stack_client.close()
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
@@ -54,7 +74,7 @@ class TestNvidiaParameters(unittest.TestCase):
else:
assert actual_json[key] == value
def test_optional_parameter_passed(self):
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
@@ -76,15 +96,27 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
@@ -97,7 +129,7 @@ class TestNvidiaParameters(unittest.TestCase):
}
)
def test_required_parameter_passed(self):
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
@@ -127,25 +159,40 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
# Catch warnings for required parameters that the provider does not support
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"use_dora",
"quantize_base",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
# Create a training config with unsupported parameters
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
@@ -215,10 +262,32 @@ class TestNvidiaParameters(unittest.TestCase):
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required but unsupported parameters
"rank",
"use_dora",
"quantize_base",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
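
The record-and-scan pattern used throughout these tests could also be written with pytest.warns, which fails the test when no matching warning is emitted. A self-contained sketch (an assumption, not part of the commit; the warning text is illustrative):

import warnings
import pytest

def test_pytest_warns_equivalent():
    # pytest.warns passes when at least one emitted warning matches the pattern.
    with pytest.warns(UserWarning, match="apply_lora_to_mlp"):
        warnings.warn("apply_lora_to_mlp is not supported and will be ignored", UserWarning)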
@pytest.fixture
def llama_stack_client():
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()
yield client
if hasattr(client, "close"):
client.close()
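
A hypothetical consumer of the fixture above (an assumption, not part of the commit), assuming the patched constructor yields the mocked client:

def test_fixture_provides_post_training(llama_stack_client):
    # The fixture handles environment setup, initialize(), and close().
    assert hasattr(llama_stack_client, "post_training")
    assert hasattr(llama_stack_client, "inference")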
if __name__ == "__main__":
unittest.main()


@@ -7,7 +7,7 @@
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import warnings
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
@@ -19,6 +19,7 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
)
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from .mock_llama_stack_client import MockLlamaStackClient
class TestNvidiaPostTraining(unittest.TestCase):
@@ -28,10 +29,9 @@ class TestNvidiaPostTraining(unittest.TestCase):
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002" # mocking llama stack base url
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
## TODO: once health checks for the customizer are enabled, add test cases for NVIDIA_CUSTOMIZER_URL
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
@@ -125,15 +125,29 @@ class TestNvidiaPostTraining(unittest.TestCase):
optimizer_config=optimizer_config,
)
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
# Warnings for required LoRA config parameters that are unsupported
fields = [
"apply_lora_to_mlp",
"rank",
"use_dora",
"lora_attn_modules",
"quantize_base",
"apply_lora_to_output",
]
for field in fields:
assert any(field in str(warning.message) for warning in w)
# check the output is a PostTrainingJob
# Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
@@ -149,7 +163,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": ""},
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",