unit test update, warnings for unsupported parameters

Authored by Ubuntu on 2025-03-12 14:17:26 +00:00, committed by raspawar
parent 152261a249
commit bd9b6a6e00
5 changed files with 413 additions and 429 deletions

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
import os
import warnings
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@@ -15,27 +16,27 @@ class NvidiaPostTrainingConfig(BaseModel):
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key.",
)
user_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
description="The NVIDIA user ID.",
)
dataset_namespace: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace.",
)
access_policies: Optional[dict] = Field(
default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {}),
description="The NVIDIA access policies.",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
description="The NVIDIA project ID.",
)
# ToDo: validate this, add default value
@@ -54,11 +55,35 @@ class NvidiaPostTrainingConfig(BaseModel):
description="Maximum number of retries for the NVIDIA Post Training API",
)
# ToDo: validate this, add default value
output_model_dir: str = Field(
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
description="Directory to save the output model",
)
# warning for default values
def __post_init__(self):
default_values = []
if os.getenv("NVIDIA_OUTPUT_MODEL_DIR") is None:
default_values.append("output_model_dir='test-example-model@v1'")
if os.getenv("NVIDIA_PROJECT_ID") is None:
default_values.append("project_id='test-project'")
if os.getenv("NVIDIA_USER_ID") is None:
default_values.append("user_id='llama-stack-user'")
if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
default_values.append("dataset_namespace='default'")
if os.getenv("NVIDIA_ACCESS_POLICIES") is None:
default_values.append("access_policies='{}'")
if os.getenv("NVIDIA_CUSTOMIZER_URL") is None:
default_values.append("customizer_url='http://nemo.test'")
if default_values:
warnings.warn(
f"Using default values: {', '.join(default_values)}. "
"Please set the environment variables to avoid this default behavior.",
stacklevel=2,
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
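
A note on the __post_init__ hook added above: pydantic.BaseModel does not call __post_init__ (that hook belongs to dataclasses), so the default-value warning may never fire as written. If the project is on Pydantic v2, model_post_init is the documented hook for this. A minimal sketch under that assumption, trimmed to a single field and using an invented class name for illustration:

import os
import warnings
from typing import Any, Optional

from pydantic import BaseModel, Field


class NvidiaPostTrainingConfigSketch(BaseModel):
    # Hypothetical, trimmed-down stand-in for NvidiaPostTrainingConfig.
    project_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
        description="The NVIDIA project ID.",
    )

    def model_post_init(self, __context: Any) -> None:
        # Pydantic v2 invokes this hook after validation, so the warning actually fires.
        if os.getenv("NVIDIA_PROJECT_ID") is None:
            warnings.warn(
                "Using default values: project_id='test-project'. "
                "Please set the environment variables to avoid this default behavior.",
                stacklevel=2,
            )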

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
@@ -190,6 +191,13 @@ class NvidiaPostTrainingAdapter:
job_uuid: str - Unique identifier for the job
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
logger_config: Dict[str, Any] - Configuration for logging
Environment Variables:
- NVIDIA_PROJECT_ID: ID of the project
- NVIDIA_USER_ID: ID of the user
- NVIDIA_ACCESS_POLICIES: Access policies for the project
- NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
- NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
"""
# map model to nvidia model name
model_mapping = {
@@ -198,9 +206,30 @@ class NvidiaPostTrainingAdapter:
}
nvidia_model = model_mapping.get(model, model)
# Check for unsupported parameters
if checkpoint_dir or hyperparam_search_config or logger_config:
warnings.warn(
"Parameters: checkpoint_dir, hyperparam_search_config and logger_config are not supported and will be ignored"
)
def warn_unsupported_params(config_dict: Dict[str, Any], supported_keys: List[str], config_name: str) -> None:
"""Helper function to warn about unsupported parameters in a config dictionary."""
unsupported_params = [k for k in config_dict.keys() if k not in supported_keys]
if unsupported_params:
warnings.warn(f"Parameters: {unsupported_params} in {config_name} not supported and will be ignored.")
# Check for unsupported parameters
warn_unsupported_params(training_config, ["n_epochs", "data_config", "optimizer_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], ["dataset_id", "batch_size"], "DataConfig")
warn_unsupported_params(training_config["optimizer_config"], ["lr"], "OptimizerConfig")
output_model = self.config.output_model_dir
if output_model == "default":
warnings.warn("output_model_dir set via default value, will be ignored")
# Prepare base job configuration
job_config = {
"config": nvidia_model,
@@ -226,6 +255,7 @@ class NvidiaPostTrainingAdapter:
# Extract LoRA-specific parameters
lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
job_config["hyperparameters"]["lora"] = lora_config
warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config")
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
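
For reference, a standalone sketch of the warn_unsupported_params helper defined in the hunk above, pulled out of the adapter so the emitted UserWarning is easy to see; the sample config and its warmup_steps key are made up for illustration:

import warnings
from typing import Any, Dict, List


def warn_unsupported_params(config_dict: Dict[str, Any], supported_keys: List[str], config_name: str) -> None:
    """Warn about keys in config_dict that will be ignored by the NVIDIA customizer."""
    unsupported_params = [k for k in config_dict.keys() if k not in supported_keys]
    if unsupported_params:
        warnings.warn(f"Parameters: {unsupported_params} in {config_name} not supported and will be ignored.")


# Emits: UserWarning: Parameters: ['warmup_steps'] in OptimizerConfig not supported and will be ignored.
warn_unsupported_params({"lr": 0.0001, "warmup_steps": 100}, ["lr"], "OptimizerConfig")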

View file

@@ -13,47 +13,13 @@
import logging
from typing import Tuple
import httpx
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
async def _get_health(url: str) -> Tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NvidiaPostTrainingConfig) -> None:
"""
Check if the server is running and ready
Args:
url (str): URL of the server
Raises:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
# ToDo: implement these once health checks for the customizer are enabled
async def _get_health(url: str) -> Tuple[bool, bool]: ...
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
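
If the customizer health endpoints become available, the stubs above could be filled back in along the lines of the removed NIM check. A rough sketch, assuming the customizer exposes the same /v1/health/live and /v1/health/ready routes and that the config carries a customizer_url field (neither is confirmed by this diff):

from typing import Tuple

import httpx

from .config import NvidiaPostTrainingConfig


async def _get_health(url: str) -> Tuple[bool, bool]:
    # Mirrors the removed NIM helper: query {url}/v1/health/{live,ready}.
    async with httpx.AsyncClient() as client:
        live = await client.get(f"{url}/v1/health/live")
        ready = await client.get(f"{url}/v1/health/ready")
    return live.status_code == 200, ready.status_code == 200


async def check_health(config: NvidiaPostTrainingConfig) -> None:
    try:
        is_live, is_ready = await _get_health(config.customizer_url)
        if not is_live:
            raise ConnectionError("NVIDIA customizer is not running")
        if not is_ready:
            raise ConnectionError("NVIDIA customizer is not ready")
    except httpx.ConnectError as e:
        raise ConnectionError(f"Failed to connect to NVIDIA customizer: {e}") from e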

View file

@@ -1,386 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"]
@pytest.mark.integration
@pytest.fixture(scope="session")
def post_training_provider_available(llama_stack_client):
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
return len(post_training_providers) > 0
@pytest.mark.integration
def test_post_training_provider_registration(llama_stack_client, post_training_provider_available):
"""Check if post_training is in the api list.
This is a sanity check to ensure the provider is registered."""
if not post_training_provider_available:
pytest.skip("post training provider not available")
providers = llama_stack_client.providers.list()
post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES]
assert len(post_training_providers) > 0
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_BASE_URL"] = "http://nim.test"
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
self.llama_stack_client.initialize = MagicMock(return_value=None)
_ = self.llama_stack_client.initialize()
@patch("requests.post")
def test_supervised_fine_tune(self, mock_post):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "default/sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
mock_post.return_value = mock_response
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with patch.object(
self.llama_stack_client.post_training,
"supervised_fine_tune",
return_value={
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "created",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset_id": "sample-basic-test",
"output_model": "default/job-1234",
},
):
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self.assertEqual(training_job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertEqual(training_job["status"], "created")
self.assertEqual(training_job["model"], "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(training_job["dataset_id"], "sample-basic-test")
@patch("requests.get")
def test_get_job_status(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
mock_get.return_value = mock_response
with patch.object(
self.llama_stack_client.post_training.job,
"status",
return_value={
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "completed",
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
},
):
status = self.llama_stack_client.post_training.job.status("cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertEqual(status["status"], "completed")
self.assertEqual(status["steps_completed"], 1210)
self.assertEqual(status["epochs_completed"], 2)
self.assertEqual(status["percentage_done"], 100.0)
self.assertEqual(status["best_epoch"], 2)
self.assertEqual(status["train_loss"], 1.718016266822815)
self.assertEqual(status["val_loss"], 1.8661999702453613)
@patch("requests.get")
def test_get_job(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {"name": "meta-llama/Llama-3.1-8B-Instruct", "base_model": "meta-llama/Llama-3.1-8B-Instruct"},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
mock_get.return_value = mock_response
client = MagicMock()
with patch.object(
client.post_training,
"get_job",
return_value={
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "completed",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset_id": "sample-basic-test",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"adapter_dim": 16,
"output_model": "default/job-1234",
},
):
job = client.post_training.get_job("cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertEqual(job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertEqual(job["status"], "completed")
self.assertEqual(job["model"], "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(job["dataset_id"], "sample-basic-test")
self.assertEqual(job["batch_size"], 16)
self.assertEqual(job["epochs"], 2)
self.assertEqual(job["learning_rate"], 0.0001)
self.assertEqual(job["adapter_dim"], 16)
self.assertEqual(job["output_model"], "default/job-1234")
@patch("requests.delete")
def test_cancel_job(self, mock_delete):
mock_response = MagicMock()
mock_response.status_code = 200
mock_delete.return_value = mock_response
client = MagicMock()
with patch.object(client.post_training, "cancel_job", return_value=True):
result = client.post_training.cancel_job("cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertTrue(result)
@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_async_supervised_fine_tune(self, mock_post):
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "created",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset_id": "sample-basic-test",
"output_model": "default/job-1234",
}
)
mock_post.return_value.__aenter__.return_value = mock_response
client = MagicMock()
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with patch.object(
client.post_training,
"supervised_fine_tune_async",
AsyncMock(
return_value={
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "created",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset_id": "sample-basic-test",
"output_model": "default/job-1234",
}
),
):
training_job = await client.post_training.supervised_fine_tune_async(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self.assertEqual(training_job["id"], "cust-JGTaMbJMdqjJU8WbQdN9Q2")
self.assertEqual(training_job["status"], "created")
self.assertEqual(training_job["model"], "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(training_job["dataset_id"], "sample-basic-test")
@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_inference_with_fine_tuned_model(self, mock_post):
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"id": "cmpl-123456",
"object": "text_completion",
"created": 1677858242,
"model": "job-1234",
"choices": [
{
"text": "The next GTC will take place in the middle of March, 2023.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
}
)
mock_post.return_value.__aenter__.return_value = mock_response
client = MagicMock()
with patch.object(
client.inference,
"completion",
AsyncMock(
return_value={
"id": "cmpl-123456",
"object": "text_completion",
"created": 1677858242,
"model": "job-1234",
"choices": [
{
"text": "The next GTC will take place in the middle of March, 2023.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
}
),
):
response = await client.inference.completion(
content="When is the upcoming GTC event? GTC 2018 attracted over 8,400 attendees. Due to the COVID pandemic of 2020, GTC 2020 was converted to a digital event and drew roughly 59,000 registrants. The 2021 GTC keynote, which was streamed on YouTube on April 12, included a portion that was made with CGI using the Nvidia Omniverse real-time rendering platform. This next GTC will take place in the middle of March, 2023. Answer: ",
stream=False,
model_id="job-1234",
sampling_params={
"max_tokens": 128,
},
)
self.assertEqual(response["model"], "job-1234")
self.assertEqual(
response["choices"][0]["text"], "The next GTC will take place in the middle of March, 2023."
)
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,349 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
from llama_stack_client.types.post_training_job import PostTrainingJob
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002" # mocking llama stack base url
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
## ToDo: once health checks for the customizer are enabled, include test cases for NVIDIA_CUSTOMIZER_URL
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
@patch("llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request")
def test_supervised_fine_tune(self, mock_make_request):
"""Test the supervised fine-tuning API call.
ToDo: add tests for env variables."""
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16, adapter_dropout=0.1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
# check the output is a PostTrainingJob
# Note: although the server-side type is llama_stack.apis.post_training.PostTrainingJob,
# after llama_stack_client initialization it is translated to llama_stack_client.types.post_training_job.PostTrainingJob
assert isinstance(training_job, PostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
mock_make_request.assert_called_once()
self._assert_request(
mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": ""},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
@patch("llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request")
def test_get_job_status(self, mock_make_request):
mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.llama_stack_client.post_training.job.status(job_uuid=job_id)
assert isinstance(status, JobStatusResponse)
assert status.status == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
mock_make_request.assert_called_once()
self._assert_request(
mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
@patch("llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request")
def test_get_job(self, mock_make_request):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.llama_stack_client.post_training.job.list()
assert isinstance(jobs, list)
assert len(jobs) == 1
job = jobs[0]
assert job.job_uuid == job_id
assert job.status == "completed"
mock_make_request.assert_called_once()
self._assert_request(
mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
@patch("llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request")
def test_cancel_job(self, mock_make_request):
mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.llama_stack_client.post_training.job.cancel(job_uuid=job_id)
assert result is None
# Verify the correct request was made
mock_make_request.assert_called_once()
self._assert_request(
mock_make_request, "POST", f"/v1/customization/jobs/{job_id}/cancel", expected_params={"job_id": job_id}
)
@pytest.mark.asyncio
@patch("llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request")
async def test_async_supervised_fine_tune(self, mock_make_request):
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"status": "created",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"dataset_id": "sample-basic-test",
"output_model": "default/job-1234",
}
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16, adapter_dropout=0.1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
training_job = await self.llama_stack_client.post_training.supervised_fine_tune_async(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
assert training_job["job_uuid"] == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
assert training_job["status"] == "created"
mock_make_request.assert_called_once()
call_args = mock_make_request.call_args
assert call_args[1]["method"] == "POST"
assert call_args[1]["path"] == "/v1/customization/jobs"
@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_inference_with_fine_tuned_model(self, mock_post):
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"id": "cmpl-123456",
"object": "text_completion",
"created": 1677858242,
"model": "job-1234",
"choices": [
{
"text": "The next GTC will take place in the middle of March, 2023.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
}
)
mock_post.return_value.__aenter__.return_value = mock_response
response = await self.llama_stack_client.inference.completion(
content="When is the upcoming GTC event? GTC 2018 attracted over 8,400 attendees. Due to the COVID pandemic of 2020, GTC 2020 was converted to a digital event and drew roughly 59,000 registrants. The 2021 GTC keynote, which was streamed on YouTube on April 12, included a portion that was made with CGI using the Nvidia Omniverse real-time rendering platform. This next GTC will take place in the middle of March, 2023. Answer: ",
stream=False,
model_id="job-1234",
sampling_params={
"max_tokens": 128,
},
)
assert response["model"] == "job-1234"
assert response["choices"][0]["text"] == "The next GTC will take place in the middle of March, 2023."
if __name__ == "__main__":
unittest.main()