Mock LlamaStackClient in unit tests

This commit is contained in:
raspawar 2025-03-24 15:14:17 +05:30
parent 0d4dc06a3c
commit 34db80fb15
4 changed files with 294 additions and 44 deletions


@@ -7,16 +7,30 @@
import os
import unittest
import warnings
from unittest.mock import patch
from unittest.mock import patch, AsyncMock, MagicMock
import pytest
import atexit
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)
from .mock_llama_stack_client import MockLlamaStackClient
# Create a mock session
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()
patch("aiohttp.ClientSession", return_value=mock_session).start()
atexit.register(lambda: patch.stopall())
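
The module-level block above replaces aiohttp.ClientSession with a MagicMock whose async context-manager hooks (__aenter__/__aexit__) are AsyncMocks, so any `async with session:` in the client returns the same mock. The MockLlamaStackClient imported above comes from a new mock_llama_stack_client.py added by this commit but not shown in this view; the following is a hypothetical sketch of what such a mock could look like, assuming the tests only need initialize(), close(), and a post_training namespace (all names below are illustrative, not the committed implementation):

# Hypothetical sketch of mock_llama_stack_client.py; the committed file is
# not shown in this diff, so behavior here is an assumption.
from unittest.mock import MagicMock

class MockLlamaStackClient:
    """Stand-in for LlamaStackAsLibraryClient that never opens network sessions."""

    def __init__(self, config_name: str):
        self.config_name = config_name
        # The tests only exercise post_training.supervised_fine_tune, so a
        # MagicMock namespace is enough to record calls.
        self.post_training = MagicMock()

    def initialize(self):
        # The real client wires up providers here; the mock has nothing to set up.
        return None

    def close(self):
        # Satisfies the hasattr(client, "close") checks in tearDown and the fixture.
        pass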
class TestNvidiaParameters(unittest.TestCase):
@@ -25,8 +39,10 @@ class TestNvidiaParameters(unittest.TestCase):
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
# Use the mock client
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
@@ -40,6 +56,10 @@ class TestNvidiaParameters(unittest.TestCase):
}
def tearDown(self):
# Close the client if it has a close method
if hasattr(self.llama_stack_client, "close"):
self.llama_stack_client.close()
self.make_request_patcher.stop()
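
setUp and tearDown follow the standard mock-patcher lifecycle: start the patch in setUp, stop it in tearDown so it cannot leak into other tests. The same pattern in isolation, with an illustrative patch target standing in for _make_request:

# Minimal standalone sketch of the start/stop lifecycle used above;
# os.getcwd is an illustrative target, not part of this commit.
from unittest.mock import patch
import os

patcher = patch("os.getcwd", return_value="/tmp/fake")
mock_getcwd = patcher.start()
try:
    assert os.getcwd() == "/tmp/fake"
finally:
    patcher.stop()  # without this, the patch leaks into later tests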
def _assert_request_params(self, expected_json):
@@ -54,7 +74,7 @@ class TestNvidiaParameters(unittest.TestCase):
else:
assert actual_json[key] == value
def test_optional_parameter_passed(self):
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
@@ -76,15 +96,27 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
warning_texts = [str(warning.message) for warning in w]
fields = [
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
{
@@ -97,7 +129,7 @@ class TestNvidiaParameters(unittest.TestCase):
}
)
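
For reference, the capture-and-scan warning pattern the test relies on, in isolation (a self-contained sketch, independent of the Llama Stack APIs; emit() stands in for the adapter code that warns about unsupported parameters):

import warnings

def emit():
    # Stand-in for the adapter warning about an unsupported parameter.
    warnings.warn("apply_lora_to_output is not supported", UserWarning)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, no deduplication
    emit()

texts = [str(w.message) for w in caught]
assert any("apply_lora_to_output" in t for t in texts)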
def test_required_parameter_passed(self):
def test_required_parameters_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
@@ -127,25 +159,40 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
# Catch warnings for unsupported required parameters
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args

assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
warning_texts = [str(warning.message) for warning in w]
fields = [
"rank",
"use_dora",
"quantize_base",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
# Create a training config with unsupported parameters
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
@@ -215,10 +262,32 @@
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required unsupported parameters
"rank",
"use_dora",
"quantize_base",
"apply_lora_to_output",
"lora_attn_modules",
"apply_lora_to_mlp",
]
for field in fields:
assert any(field in text for text in warning_texts)
@pytest.fixture
def llama_stack_client():
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()
yield client
if hasattr(client, "close"):
client.close()
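
No test in this diff consumes the new llama_stack_client fixture yet; a minimal sketch of how one would (the test name is hypothetical, not part of this commit):

# Hypothetical consumer of the fixture above; pytest injects the initialized
# client and runs the post-yield cleanup (close) after the test finishes.
def test_fixture_yields_client(llama_stack_client):
    assert llama_stack_client is not None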
if __name__ == "__main__":
unittest.main()