Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-31 05:53:53 +00:00
commit 34db80fb15: mock llamastackclient in unit tests
parent 0d4dc06a3c
4 changed files with 294 additions and 44 deletions
@@ -7,16 +7,30 @@
 import os
 import unittest
 import warnings
-from unittest.mock import patch
+from unittest.mock import patch, AsyncMock, MagicMock
 import pytest
+import atexit
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
 from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfig,
     TrainingConfigDataConfig,
+    TrainingConfigEfficiencyConfig,
     TrainingConfigOptimizerConfig,
-    TrainingConfigEfficiencyConfig,
 )
+from .mock_llama_stack_client import MockLlamaStackClient
+
+# Create a mock session
+mock_session = MagicMock()
+mock_session.closed = False
+mock_session.close = AsyncMock()
+mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+mock_session.__aexit__ = AsyncMock()
+
+patch("aiohttp.ClientSession", return_value=mock_session).start()
+
+atexit.register(lambda: patch.stopall())
+
 
 class TestNvidiaParameters(unittest.TestCase):
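Note: the module-level block added here stubs out `aiohttp.ClientSession` for the whole test module. `patch(...).start()` activates the patch without a `with` block, and the `atexit` hook stops all started patches when the interpreter exits. A standalone sketch of the same pattern, assuming `aiohttp` is installed; the `fetch` coroutine and its use of the session are illustrative stand-ins, not llama-stack code:

```python
import asyncio
import atexit
from unittest.mock import AsyncMock, MagicMock, patch

# Build a session double that behaves like an async context manager.
mock_session = MagicMock()
mock_session.closed = False
mock_session.close = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock()

# Start the patch for the lifetime of the process; no `with` block needed.
patch("aiohttp.ClientSession", return_value=mock_session).start()
atexit.register(patch.stopall)  # undo every started patch at interpreter exit

async def fetch():
    import aiohttp  # attribute lookup resolves to the patched constructor
    async with aiohttp.ClientSession() as session:
        return session.closed  # no real socket is ever opened

assert asyncio.run(fetch()) is False
```

The explicit `__aenter__`/`__aexit__` assignments matter because a plain `MagicMock` does not preconfigure async context-manager methods; without them, `async with` on the double would fail.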
@@ -25,8 +39,10 @@ class TestNvidiaParameters(unittest.TestCase):
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
         os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
 
-        self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
-        _ = self.llama_stack_client.initialize()
+        # Use the mock client
+        with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+            self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
+            _ = self.llama_stack_client.initialize()
 
         self.make_request_patcher = patch(
             "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
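Note: `MockLlamaStackClient` comes from `mock_llama_stack_client.py`, one of the four changed files in this commit, whose diff is not shown on this page. A hypothetical sketch of the minimal surface these tests rely on; every name and method body below is an assumption, not the actual implementation:

```python
from unittest.mock import MagicMock

class MockLlamaStackClient:
    """Hypothetical stand-in: mimics LlamaStackAsLibraryClient without
    booting providers or opening network connections."""

    def __init__(self, template_name: str):
        self.template_name = template_name  # e.g. "nvidia"
        self.post_training = MagicMock()    # absorbs supervised_fine_tune calls

    def initialize(self):
        return None  # the real client wires up the distribution here

    def close(self):
        pass  # lets tearDown's hasattr(self.llama_stack_client, "close") succeed
```

One caveat worth flagging: `patch(...)` swaps the attribute on the `llama_stack.distribution.library_client` module, but this test file imported `LlamaStackAsLibraryClient` by name, so the direct constructor call inside the `with` block still refers to the test module's original binding. This is the standard "patch where the name is looked up" caveat from the unittest.mock documentation; whether it bites here depends on how the client resolves the class internally.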
@@ -40,6 +56,10 @@ class TestNvidiaParameters(unittest.TestCase):
         }
 
     def tearDown(self):
+        # Close the client if it has a close method
+        if hasattr(self.llama_stack_client, "close"):
+            self.llama_stack_client.close()
+
         self.make_request_patcher.stop()
 
     def _assert_request_params(self, expected_json):
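Note: the `hasattr` guard keeps `tearDown` safe whether the client is the real `LlamaStackAsLibraryClient` or a mock without a `close` method. An alternative arrangement, shown here only as a sketch and not what this diff does, registers cleanups in `setUp` via `addCleanup`, which `unittest` runs even when `setUp` or the test body raises part-way:

```python
import unittest
from unittest.mock import MagicMock, patch

class Example(unittest.TestCase):
    def setUp(self):
        self.client = MagicMock()           # stand-in for the stack client
        self.addCleanup(self.client.close)  # runs even if the test errors out

        patcher = patch("os.getcwd")        # any patcher works the same way
        self.mock = patcher.start()
        self.addCleanup(patcher.stop)       # replaces an explicit tearDown

    def test_something(self):
        self.client.close.assert_not_called()
```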
@@ -54,7 +74,7 @@ class TestNvidiaParameters(unittest.TestCase):
         else:
             assert actual_json[key] == value
 
-    def test_optional_parameter_passed(self):
+    def test_customizer_parameters_passed(self):
         """Test scenario 1: When an optional parameter is passed and value is correctly set."""
         custom_adapter_dim = 32  # Different from default of 8
         algorithm_config = LoraFinetuningConfig(
@@ -76,15 +96,27 @@ class TestNvidiaParameters(unittest.TestCase):
             optimizer_config=optimizer_config,
         )
 
-        self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid="test-job",
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid="test-job",
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )
+
+            warning_texts = [str(warning.message) for warning in w]
+
+            fields = [
+                "apply_lora_to_output",
+                "lora_attn_modules",
+                "apply_lora_to_mlp",
+            ]
+            for field in fields:
+                assert any(field in text for text in warning_texts)
 
         self._assert_request_params(
             {
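Note: the `catch_warnings(record=True)` block is what lets the test assert on warning text. Recorded warnings land in `w` instead of being printed, and `simplefilter("always")` disables Python's default "once per call site" deduplication so repeated warnings are all captured. The same pattern in isolation; `deprecated_param` is an illustrative stand-in for the adapter code:

```python
import warnings

def deprecated_param():
    # stand-in for the adapter warning about an unsupported field
    warnings.warn("apply_lora_to_output is not supported", UserWarning)

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # without this, the second call is deduplicated
    deprecated_param()
    deprecated_param()

warning_texts = [str(warning.message) for warning in w]
assert len(warning_texts) == 2
assert any("apply_lora_to_output" in text for text in warning_texts)
```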
@@ -97,7 +129,7 @@ class TestNvidiaParameters(unittest.TestCase):
             }
         )
 
-    def test_required_parameter_passed(self):
+    def test_required_parameters_passed(self):
         """Test scenario 2: When required parameters are passed."""
         required_model = "meta-llama/Llama-3.1-8B-Instruct"
         required_dataset_id = "required-dataset"
@@ -127,25 +159,40 @@ class TestNvidiaParameters(unittest.TestCase):
             optimizer_config=optimizer_config,
         )
 
-        self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=required_job_uuid,  # Required parameter
-            model=required_model,  # Required parameter
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        # catch required unsupported parameters warnings
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid=required_job_uuid,  # Required parameter
+                model=required_model,  # Required parameter
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )
 
-        self.mock_make_request.assert_called_once()
-        call_args = self.mock_make_request.call_args
+            self.mock_make_request.assert_called_once()
+            call_args = self.mock_make_request.call_args
 
-        assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
-        assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
+            assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
+            assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
+
+            warning_texts = [str(warning.message) for warning in w]
+
+            fields = [
+                "rank",
+                "use_dora",
+                "quantize_base",
+                "apply_lora_to_output",
+                "lora_attn_modules",
+                "apply_lora_to_mlp",
+            ]
+            for field in fields:
+                assert any(field in text for text in warning_texts)
 
     def test_unsupported_parameters_warning(self):
         """Test that warnings are raised for unsupported parameters."""
         # Create a training config with unsupported parameters
         data_config = TrainingConfigDataConfig(
             dataset_id="test-dataset",
             batch_size=8,
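Note: the assertions above rely on `Mock.call_args` indexing: index `[0]` is the positional-argument tuple and `[1]` is the keyword-argument dict, so `call_args[1]["json"]` digs into the `json=` keyword of the recorded `_make_request` call. A minimal illustration; the method and path arguments and the payload shape are hypothetical:

```python
from unittest.mock import MagicMock

make_request = MagicMock()
make_request(
    "POST", "/v1/customization/jobs",  # hypothetical positional args
    json={"config": "meta/llama-3.1-8b-instruct",
          "dataset": {"name": "required-dataset"}},
)

make_request.assert_called_once()
call_args = make_request.call_args
assert call_args[0] == ("POST", "/v1/customization/jobs")  # positional tuple
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args.kwargs == call_args[1]  # Python 3.8+ named accessor
```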
@@ -215,10 +262,32 @@ class TestNvidiaParameters(unittest.TestCase):
             "gradient_accumulation_steps",
             "max_validation_steps",
             "dtype",
+            # required unsupported parameters
+            "rank",
+            "use_dora",
+            "quantize_base",
+            "apply_lora_to_output",
+            "lora_attn_modules",
+            "apply_lora_to_mlp",
         ]
         for field in fields:
             assert any(field in text for text in warning_texts)
 
 
+@pytest.fixture
+def llama_stack_client():
+    os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
+    os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+    os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
+
+    with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+        client = LlamaStackAsLibraryClient("nvidia")
+        _ = client.initialize()
+        yield client
+
+        if hasattr(client, "close"):
+            client.close()
+
+
 if __name__ == "__main__":
     unittest.main()
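Note: the new `llama_stack_client` fixture follows pytest's yield-fixture shape: everything before `yield` is setup, the yielded value is what a test receives, and the lines after `yield` are teardown that pytest runs once the test finishes, pass or fail. The generic pattern, with illustrative names:

```python
import pytest

@pytest.fixture
def resource():
    handle = {"open": True}  # setup: build whatever the test needs
    yield handle             # injected into any test with a `resource` argument
    handle["open"] = False   # teardown: runs after the test, even on failure

def test_resource_is_open(resource):
    assert resource["open"] is True
```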