parameter validation, test cases

This commit is contained in:
Ubuntu 2025-03-19 18:57:16 +00:00
parent d7340da7a6
commit 87ce96c1f7
4 changed files with 453 additions and 70 deletions

View file

@ -9,6 +9,8 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
# TODO: add default values for all fields
class NvidiaPostTrainingConfig(BaseModel):
"""Configuration for NVIDIA Post Training implementation."""
@ -58,3 +60,54 @@ class NvidiaPostTrainingConfig(BaseModel):
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
}
class SFTLoRADefaultConfig(BaseModel):
"""NVIDIA-specific training configuration with default values."""
# ToDo: split into SFT and LoRA configs??
# General training parameters
n_epochs: int = 50
# NeMo customizer specific parameters
log_every_n_steps: Optional[int] = None
val_check_interval: float = 0.25
sequence_packing_enabled: bool = False
weight_decay: float = 0.01
lr: float = 0.0001
# SFT specific parameters
hidden_dropout: Optional[float] = None
attention_dropout: Optional[float] = None
ffn_dropout: Optional[float] = None
# LoRA default parameters
lora_adapter_dim: int = 8
lora_adapter_dropout: Optional[float] = None
lora_alpha: int = 16
# Data config
batch_size: int = 8
@classmethod
def sample_config(cls) -> Dict[str, Any]:
"""Return a sample configuration for NVIDIA training."""
return {
"n_epochs": 50,
"log_every_n_steps": 10,
"val_check_interval": 0.25,
"sequence_packing_enabled": False,
"weight_decay": 0.01,
"hidden_dropout": 0.1,
"attention_dropout": 0.1,
"lora_adapter_dim": 8,
"lora_alpha": 16,
"data_config": {
"dataset_id": "default",
"batch_size": 8,
},
"optimizer_config": {
"lr": 0.0001,
},
}
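
As a rough usage sketch (not part of this commit), the new defaults class can be instantiated directly and individual fields overridden; the import path below mirrors the one used elsewhere in this change:

from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig

# Start from the NVIDIA-specific defaults, overriding only what differs for this run
defaults = SFTLoRADefaultConfig(n_epochs=3, lora_adapter_dim=32)
print(defaults.lr)        # 0.0001 (class default)
print(defaults.n_epochs)  # 3 (overridden)

# The classmethod returns the sample dictionary defined above
print(SFTLoRADefaultConfig.sample_config()["optimizer_config"])  # {'lr': 0.0001}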

View file

@ -20,9 +20,8 @@ from llama_stack.apis.post_training import (
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .models import _MODEL_ENTRIES
@ -106,6 +105,11 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
ToDo: Support for schema input for filtering.
""" """
params = {"page": page, "page_size": page_size, "sort": sort} params = {"page": page, "page_size": page_size, "sort": sort}
@ -137,6 +141,22 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
""" """
response = await self._make_request( response = await self._make_request(
"GET", "GET",
@ -156,23 +176,20 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancels a customization job."""
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
)
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
"""Get artifacts for a specific training job."""
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
"""Get all post-training artifacts."""
raise NotImplementedError("Job artifacts are not implemented yet")
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: Dict[str, Any],
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
@ -195,15 +212,22 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
training_config: TrainingConfig - Configuration for training
model: str - Model identifier
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints (currently ignored)
job_uuid: str - Unique identifier for the job (currently ignored)
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search (currently ignored)
logger_config: Dict[str, Any] - Configuration for logging (currently ignored)
Environment Variables:
- NVIDIA_API_KEY: str - API key for the NVIDIA API
Default: None
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
Default: "default"
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
Default: "http://nemo.test"
- NVIDIA_PROJECT_ID: str - ID of the project
Default: "test-project"
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
Default: "test-example-model@v1"
Supported models:
- meta/llama-3.1-8b-instruct
@ -213,57 +237,100 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
Supported Parameters:
- TrainingConfig:
- n_epochs: int - Number of epochs to train
Default: 50
- data_config: DataConfig - Configuration for the dataset
- optimizer_config: OptimizerConfig - Configuration for the optimizer
- dtype: str - Data type for training
not supported (users are informed via warnings)
- efficiency_config: EfficiencyConfig - Configuration for efficiency
not supported (users are informed via warnings)
- max_steps_per_epoch: int - Maximum number of steps per epoch
Default: 1000
## NeMo customizer specific parameters
- log_every_n_steps: int - Log every n steps
Default: None
- val_check_interval: float - Validation check interval
Default: 0.25
- sequence_packing_enabled: bool - Whether sequence packing is enabled
Default: False
## NeMo customizer specific SFT parameters
- hidden_dropout: float - Hidden dropout probability
Default: None (0.0-1.0)
- attention_dropout: float - Attention dropout probability
Default: None (0.0-1.0)
- ffn_dropout: float - FFN dropout probability
Default: None (0.0-1.0)
- DataConfig:
- dataset_id: str - Dataset ID
- batch_size: int - Batch size
Default: 8
- OptimizerConfig:
- lr: float - Learning rate
Default: 0.0001
## NeMo customizer specific parameter
- weight_decay: float - Weight decay
Default: 0.01
- LoRA config:
## NeMo customizer specific LoRA parameters
- adapter_dim: int - Adapter dimension
Default: 8 (supports powers of 2)
- adapter_dropout: float - Adapter dropout probability
Default: None (0.0-1.0)
- alpha: int - Scaling factor for the LoRA update
Default: 16
Note:
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
- Some parameters from TrainingConfig, DataConfig, and OptimizerConfig are not supported (users are informed via warnings)
"""
# Map model to nvidia model name
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters
unsupported_method_params = []
if checkpoint_dir:
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
if hyperparam_search_config:
unsupported_method_params.append("hyperparam_search_config")
if logger_config:
unsupported_method_params.append("logger_config")
if unsupported_method_params:
warnings.warn(f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored")
# Define all supported parameters
supported_params = {
"training_config": {
"n_epochs",
"data_config",
"optimizer_config",
"log_every_n_steps",
"val_check_interval",
"sequence_packing_enabled",
"hidden_dropout",
"attention_dropout",
"ffn_dropout",
},
"data_config": {"dataset_id", "batch_size"},
"optimizer_config": {"lr", "weight_decay"},
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
}
# Validate all parameters at once
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
warn_unsupported_params(
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
)
output_model = self.config.output_model_dir
if output_model == "default":
warnings.warn("output_model_dir set via default value, will be ignored")
# Prepare base job configuration
job_config = {
"config": nvidia_model,
@ -274,9 +341,19 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
"hyperparameters": { "hyperparameters": {
"training_type": "sft", "training_type": "sft",
"finetuning_type": "lora", "finetuning_type": "lora",
"epochs": training_config.get("n_epochs", 1), **{
"batch_size": training_config["data_config"].get("batch_size", 8), k: v
"learning_rate": training_config["optimizer_config"].get("lr", 0.0001), for k, v in {
"epochs": training_config.get("n_epochs"),
"batch_size": training_config["data_config"].get("batch_size"),
"learning_rate": training_config["optimizer_config"].get("lr"),
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
"log_every_n_steps": training_config.get("log_every_n_steps"),
"val_check_interval": training_config.get("val_check_interval"),
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
}.items()
if v is not None
},
},
"project": self.config.project_id,
# TODO: ignored ownership, add it later
@ -284,18 +361,37 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
"output_model": output_model, "output_model": output_model,
} }
# Handle SFT-specific optional parameters
job_config["hyperparameters"]["sft"] = {
k: v
for k, v in {
"ffn_dropout": training_config.get("ffn_dropout"),
"hidden_dropout": training_config.get("hidden_dropout"),
"attention_dropout": training_config.get("attention_dropout"),
}.items()
if v is not None
}
# Remove the sft dictionary if it's empty
if not job_config["hyperparameters"]["sft"]:
job_config["hyperparameters"].pop("sft")
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
}.items()
if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
# Create the customization job
response = await self._make_request(
method="POST",
@ -305,12 +401,12 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
async def preference_optimize(
@ -326,5 +422,4 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
raise NotImplementedError("Preference optimization is not implemented yet") raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
"""Get the container logs of a customization job."""
raise NotImplementedError("Job logs are not implemented yet") raise NotImplementedError("Job logs are not implemented yet")

View file

@ -4,20 +4,54 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any, Dict, Set, Tuple
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
unsupported_params = [k for k in keys if k not in supported_keys]
if unsupported_params:
warnings.warn(f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.")
def validate_training_params(
training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
) -> None:
"""
Validates training parameters against supported keys.
Args:
training_config: Dictionary containing training configuration parameters
supported_keys: Set of supported parameter keys
config_name: Name of the configuration for warning messages
"""
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
training_config_fields = set(TrainingConfig.__annotations__.keys())
# Check for unsupported parameters:
# - not in either of the configs
# - in TrainingConfig but not in SFTLoRADefaultConfig
unsupported_params = []
for key in training_config:
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
if key in training_config_fields or key not in sft_lora_fields:
unsupported_params.append(key)
if unsupported_params:
warnings.warn(f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.")
# ToDo: implement health checks to verify that the customizer endpoint is enabled
async def _get_health(url: str) -> Tuple[bool, bool]: ...
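
A quick illustrative snippet (not part of this commit) showing how warn_unsupported_params reports an unexpected key on a plain dict config; the helper also accepts pydantic models via the BaseModel branch above:

import warnings

from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params

# Hypothetical optimizer config with one key the NeMo customizer does not accept
optimizer_cfg = {"lr": 0.0001, "weight_decay": 0.01, "optimizer_type": "adam"}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_unsupported_params(optimizer_cfg, {"lr", "weight_decay"}, "OptimizerConfig")

print(str(caught[0].message))
# Parameters: ['optimizer_type'] in `OptimizerConfig` not supported and will be ignored.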

View file

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert actual_json[key][nested_key] == nested_value
else:
assert actual_json[key] == value
def test_optional_parameter_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim, # Custom value
adapter_dropout=0.2, # Custom value
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
}
}
)
def test_required_parameter_passed(self):
"""Test scenario 2: When required parameters are passed."""
required_model = "meta-llama/Llama-3.1-8B-Instruct"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=8)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
)
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
# Create a training config with unsupported parameters
data_config = TrainingConfigDataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
max_validation_steps=100,
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(type="LoRA"),
training_config=training_config,
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
fields = [
"checkpoint_dir",
"hyperparam_search_config",
"logger_config",
"TrainingConfig",
"DataConfig",
"OptimizerConfig",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()