mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
mock LlamaStackClient in unit tests
parent 0d4dc06a3c
commit 34db80fb15
4 changed files with 294 additions and 44 deletions
tests/unit/providers/nvidia/conftest.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+mock_session = MagicMock()
+mock_session.closed = False
+mock_session.close = AsyncMock()
+mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+mock_session.__aexit__ = AsyncMock()
+
+
+@pytest.fixture(scope="session", autouse=True)
+def patch_aiohttp_session():
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        yield
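Because the fixture above is autouse and session-scoped, every test under tests/unit/providers/nvidia gets the patched aiohttp.ClientSession without requesting it. A minimal sketch of the effect it guarantees (the test below is hypothetical, not part of this commit):

import asyncio

import aiohttp


def test_aiohttp_session_is_mocked():  # hypothetical test name
    async def probe():
        # Under the patch, ClientSession() returns the shared MagicMock,
        # so no real connector or socket is ever created.
        async with aiohttp.ClientSession() as session:
            assert session.closed is False
            await session.close()  # AsyncMock: call is recorded, no I/O happens

    asyncio.run(probe())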
tests/unit/providers/nvidia/mock_llama_stack_client.py (new file, 147 lines)
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack_client.types.algorithm_config_param import QatFinetuningConfig
+from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
+from llama_stack_client.types.post_training_job import PostTrainingJob
+
+
+class MockLlamaStackClient:
+    """Mock client for testing NVIDIA post-training functionality."""
+
+    def __init__(self, provider="nvidia"):
+        self.provider = provider
+        self.post_training = MockPostTraining()
+        self.inference = MockInference()
+        self._session = None
+
+    def initialize(self):
+        """Mock initialization method."""
+        return True
+
+    def close(self):
+        """Close any open resources."""
+        pass
+
+
+class MockPostTraining:
+    """Mock post-training module."""
+
+    def __init__(self):
+        self.job = MockPostTrainingJob()
+
+    def supervised_fine_tune(
+        self,
+        job_uuid,
+        model,
+        checkpoint_dir,
+        algorithm_config,
+        training_config,
+        logger_config,
+        hyperparam_search_config,
+    ):
+        """Mock supervised fine-tuning method."""
+        if isinstance(algorithm_config, QatFinetuningConfig):
+            raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
+
+        # Return a mock PostTrainingJob
+        return PostTrainingJob(
+            job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
+            status="created",
+            created_at="2024-12-09T04:06:28.542884",
+            updated_at="2024-12-09T04:06:28.542884",
+            model=model,
+            dataset_id=training_config.data_config.dataset_id,
+            output_model="default/job-1234",
+        )
+
+    async def supervised_fine_tune_async(
+        self,
+        job_uuid,
+        model,
+        checkpoint_dir,
+        algorithm_config,
+        training_config,
+        logger_config,
+        hyperparam_search_config,
+    ):
+        """Mock async supervised fine-tuning method."""
+        if isinstance(algorithm_config, QatFinetuningConfig):
+            raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
+
+        # Return a mock response dictionary
+        return {
+            "job_uuid": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
+            "status": "created",
+            "created_at": "2024-12-09T04:06:28.542884",
+            "updated_at": "2024-12-09T04:06:28.542884",
+            "model": model,
+            "dataset_id": training_config.data_config.dataset_id,
+            "output_model": "default/job-1234",
+        }
+
+
+class MockPostTrainingJob:
+    """Mock post-training job module."""
+
+    def status(self, job_uuid):
+        """Mock job status method."""
+        return JobStatusResponse(
+            status="completed",
+            steps_completed=1210,
+            epochs_completed=2,
+            percentage_done=100.0,
+            best_epoch=2,
+            train_loss=1.718016266822815,
+            val_loss=1.8661999702453613,
+        )
+
+    def list(self):
+        """Mock job list method."""
+        return [
+            PostTrainingJob(
+                job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
+                status="completed",
+                created_at="2024-12-09T04:06:28.542884",
+                updated_at="2024-12-09T04:21:19.852832",
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                dataset_id="sample-basic-test",
+                output_model="default/job-1234",
+            )
+        ]
+
+    def cancel(self, job_uuid):
+        """Mock job cancel method."""
+        return None
+
+
+class MockInference:
+    """Mock inference module."""
+
+    async def completion(
+        self,
+        content,
+        stream=False,
+        model_id=None,
+        sampling_params=None,
+    ):
+        """Mock completion method."""
+        return {
+            "id": "cmpl-123456",
+            "object": "text_completion",
+            "created": 1677858242,
+            "model": model_id,
+            "choices": [
+                {
+                    "text": "The next GTC will take place in the middle of March, 2023.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
+        }
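The mock deliberately mirrors only the client surface the NVIDIA tests touch: post_training.supervised_fine_tune, post_training.job.{status,list,cancel}, and inference.completion. A hedged usage sketch, reusing the canned values hard-coded above (the flat import path is an assumption for a standalone run from the test directory):

from mock_llama_stack_client import MockLlamaStackClient

client = MockLlamaStackClient()
assert client.initialize() is True

status = client.post_training.job.status("cust-JGTaMbJMdqjJU8WbQdN9Q2")
assert status.status == "completed"

jobs = client.post_training.job.list()
assert jobs[0].model == "meta-llama/Llama-3.1-8B-Instruct"

Raising NotImplementedError for QatFinetuningConfig keeps the mock honest about what the NVIDIA provider rejects, so tests exercising that path still fail loudly.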
@@ -7,16 +7,30 @@
 import os
 import unittest
+import warnings
-from unittest.mock import patch
+from unittest.mock import patch, AsyncMock, MagicMock
 import pytest
+import atexit

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
 from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfig,
     TrainingConfigDataConfig,
     TrainingConfigEfficiencyConfig,
     TrainingConfigOptimizerConfig,
 )
+from .mock_llama_stack_client import MockLlamaStackClient
+
+# Create a mock session
+mock_session = MagicMock()
+mock_session.closed = False
+mock_session.close = AsyncMock()
+mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+mock_session.__aexit__ = AsyncMock()
+
+patch("aiohttp.ClientSession", return_value=mock_session).start()
+
+atexit.register(lambda: patch.stopall())


 class TestNvidiaParameters(unittest.TestCase):
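Note the module-level pattern here: patch(...).start() activates the aiohttp mock as soon as the test module is imported, and the atexit hook calls patch.stopall() at interpreter shutdown, presumably because the library client touches aiohttp outside any single test's lifetime. The equivalent, self-unwinding pytest idiom would be a fixture; a sketch of that alternative (which the new conftest.py above already provides session-wide):

import pytest
from unittest.mock import AsyncMock, MagicMock, patch


@pytest.fixture(autouse=True)
def mocked_aiohttp_session():
    session = MagicMock()
    session.closed = False
    session.close = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock()
    with patch("aiohttp.ClientSession", return_value=session):
        yield session  # patch unwinds when each test finishes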
@@ -25,8 +39,10 @@ class TestNvidiaParameters(unittest.TestCase):
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
         os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"

-        self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
-        _ = self.llama_stack_client.initialize()
+        # Use the mock client
+        with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+            self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
+            _ = self.llama_stack_client.initialize()

         self.make_request_patcher = patch(
             "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
@@ -40,6 +56,10 @@ class TestNvidiaParameters(unittest.TestCase):
         }

     def tearDown(self):
+        # Close the client if it has a close method
+        if hasattr(self.llama_stack_client, "close"):
+            self.llama_stack_client.close()
+
         self.make_request_patcher.stop()

     def _assert_request_params(self, expected_json):
@@ -54,7 +74,7 @@ class TestNvidiaParameters(unittest.TestCase):
             else:
                 assert actual_json[key] == value

-    def test_optional_parameter_passed(self):
+    def test_customizer_parameters_passed(self):
         """Test scenario 1: When an optional parameter is passed and value is correctly set."""
         custom_adapter_dim = 32  # Different from default of 8
         algorithm_config = LoraFinetuningConfig(
@@ -76,15 +96,27 @@ class TestNvidiaParameters(unittest.TestCase):
             optimizer_config=optimizer_config,
         )

-        self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid="test-job",
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid="test-job",
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )
+
+        warning_texts = [str(warning.message) for warning in w]
+
+        fields = [
+            "apply_lora_to_output",
+            "lora_attn_modules",
+            "apply_lora_to_mlp",
+        ]
+        for field in fields:
+            assert any(field in text for text in warning_texts)

         self._assert_request_params(
             {
@@ -97,7 +129,7 @@ class TestNvidiaParameters(unittest.TestCase):
             }
         )

-    def test_required_parameter_passed(self):
+    def test_required_parameters_passed(self):
         """Test scenario 2: When required parameters are passed."""
         required_model = "meta-llama/Llama-3.1-8B-Instruct"
         required_dataset_id = "required-dataset"
@@ -127,25 +159,40 @@ class TestNvidiaParameters(unittest.TestCase):
             optimizer_config=optimizer_config,
         )

-        self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=required_job_uuid,  # Required parameter
-            model=required_model,  # Required parameter
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        # catch required unsupported parameters warnings
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid=required_job_uuid,  # Required parameter
+                model=required_model,  # Required parameter
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )

         self.mock_make_request.assert_called_once()
         call_args = self.mock_make_request.call_args

         assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
         assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id

+        warning_texts = [str(warning.message) for warning in w]
+
+        fields = [
+            "rank",
+            "use_dora",
+            "quantize_base",
+            "apply_lora_to_output",
+            "lora_attn_modules",
+            "apply_lora_to_mlp",
+        ]
+        for field in fields:
+            assert any(field in text for text in warning_texts)

     def test_unsupported_parameters_warning(self):
         """Test that warnings are raised for unsupported parameters."""
         # Create a training config with unsupported parameters
         data_config = TrainingConfigDataConfig(
             dataset_id="test-dataset",
             batch_size=8,
@@ -215,10 +262,32 @@ class TestNvidiaParameters(unittest.TestCase):
             "gradient_accumulation_steps",
             "max_validation_steps",
             "dtype",
+            # required unsupported parameters
+            "rank",
+            "use_dora",
+            "quantize_base",
+            "apply_lora_to_output",
+            "lora_attn_modules",
+            "apply_lora_to_mlp",
         ]
         for field in fields:
             assert any(field in text for text in warning_texts)
+
+
+@pytest.fixture
+def llama_stack_client():
+    os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
+    os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+    os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
+
+    with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+        client = LlamaStackAsLibraryClient("nvidia")
+        _ = client.initialize()
+        yield client
+
+        if hasattr(client, "close"):
+            client.close()


 if __name__ == "__main__":
     unittest.main()
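The new llama_stack_client fixture lets plain pytest functions sit alongside the unittest.TestCase class. A hypothetical consumer (not part of this diff), assuming the patch takes effect so the fixture yields the mock client:

def test_job_status_via_fixture(llama_stack_client):  # hypothetical test
    # Values mirror what MockPostTrainingJob.status() hard-codes.
    status = llama_stack_client.post_training.job.status("cust-JGTaMbJMdqjJU8WbQdN9Q2")
    assert status.status == "completed"
    assert status.percentage_done == 100.0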
@@ -7,7 +7,7 @@
 import os
 import unittest
 from unittest.mock import AsyncMock, MagicMock, patch
-
+import warnings
 import pytest
 from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
 from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
@@ -19,6 +19,7 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
 )

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from .mock_llama_stack_client import MockLlamaStackClient


 class TestNvidiaPostTraining(unittest.TestCase):
@@ -28,10 +29,9 @@ class TestNvidiaPostTraining(unittest.TestCase):
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
         os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"  # mocking llama stack base url

-        self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
-        _ = self.llama_stack_client.initialize()
-
-        ## ToDo: post health checks for customizer are enabled, include test cases for NVIDIA_CUSTOMIZER_URL
+        with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+            self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
+            _ = self.llama_stack_client.initialize()

     def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
         """Helper method to verify request details in mock calls."""
@@ -125,15 +125,29 @@ class TestNvidiaPostTraining(unittest.TestCase):
             optimizer_config=optimizer_config,
         )

-        training_job = self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid="1234",
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            training_job = self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid="1234",
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )
+
+        # required lora config unsupported parameters warnings
+        fields = [
+            "apply_lora_to_mlp",
+            "rank",
+            "use_dora",
+            "lora_attn_modules",
+            "quantize_base",
+            "apply_lora_to_output",
+        ]
+        for field in fields:
+            assert any(field in str(warning.message) for warning in w)

         # check the output is a PostTrainingJob
         # Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
@@ -149,7 +163,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
             "/v1/customization/jobs",
             expected_json={
                 "config": "meta/llama-3.1-8b-instruct",
-                "dataset": {"name": "sample-basic-test", "namespace": ""},
+                "dataset": {"name": "sample-basic-test", "namespace": "default"},
                 "hyperparameters": {
                     "training_type": "sft",
                     "finetuning_type": "lora",
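Both test files now share the same capture-then-assert pattern for unsupported-parameter warnings. Stripped of the client plumbing, the pattern reduces to this self-contained sketch (fine_tune_stub stands in for supervised_fine_tune):

import warnings


def fine_tune_stub():
    # Emits one warning per LoRA field the NVIDIA provider ignores.
    warnings.warn("Ignoring unsupported parameter: rank", UserWarning)
    warnings.warn("Ignoring unsupported parameter: use_dora", UserWarning)


with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # record every warning, no deduplication
    fine_tune_stub()

warning_texts = [str(warning.message) for warning in w]
for field in ["rank", "use_dora"]:
    assert any(field in text for text in warning_texts)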