From 87ce96c1f74bae09734a4d43894b47dfa93a3b50 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 19 Mar 2025 18:57:16 +0000
Subject: [PATCH] parameter validation, test cases

---
 .../remote/post_training/nvidia/config.py      |  53 +++++
 .../post_training/nvidia/post_training.py      | 221 +++++++++++++-----
 .../remote/post_training/nvidia/utils.py       |  48 +++-
 .../unit/providers/nvidia/test_parameters.py   | 201 ++++++++++++++++
 4 files changed, 453 insertions(+), 70 deletions(-)
 create mode 100644 tests/unit/providers/nvidia/test_parameters.py

diff --git a/llama_stack/providers/remote/post_training/nvidia/config.py b/llama_stack/providers/remote/post_training/nvidia/config.py
index b142cf026..7b42c8bb0 100644
--- a/llama_stack/providers/remote/post_training/nvidia/config.py
+++ b/llama_stack/providers/remote/post_training/nvidia/config.py
@@ -9,6 +9,8 @@ from typing import Any, Dict, Optional
 
 from pydantic import BaseModel, Field
 
+# TODO: add default values for all fields
+
 
 class NvidiaPostTrainingConfig(BaseModel):
     """Configuration for NVIDIA Post Training implementation."""
@@ -58,3 +60,54 @@ class NvidiaPostTrainingConfig(BaseModel):
             "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
             "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
         }
+
+
+class SFTLoRADefaultConfig(BaseModel):
+    """NVIDIA-specific training configuration with default values."""
+
+    # ToDo: split into SFT and LoRA configs??
+
+    # General training parameters
+    n_epochs: int = 50
+
+    # NeMo customizer specific parameters
+    log_every_n_steps: Optional[int] = None
+    val_check_interval: float = 0.25
+    sequence_packing_enabled: bool = False
+    weight_decay: float = 0.01
+    lr: float = 0.0001
+
+    # SFT specific parameters
+    hidden_dropout: Optional[float] = None
+    attention_dropout: Optional[float] = None
+    ffn_dropout: Optional[float] = None
+
+    # LoRA default parameters
+    lora_adapter_dim: int = 8
+    lora_adapter_dropout: Optional[float] = None
+    lora_alpha: int = 16
+
+    # Data config
+    batch_size: int = 8
+
+    @classmethod
+    def sample_config(cls) -> Dict[str, Any]:
+        """Return a sample configuration for NVIDIA training."""
+        return {
+            "n_epochs": 50,
+            "log_every_n_steps": 10,
+            "val_check_interval": 0.25,
+            "sequence_packing_enabled": False,
+            "weight_decay": 0.01,
+            "hidden_dropout": 0.1,
+            "attention_dropout": 0.1,
+            "lora_adapter_dim": 8,
+            "lora_alpha": 16,
+            "data_config": {
+                "dataset_id": "default",
+                "batch_size": 8,
+            },
+            "optimizer_config": {
+                "lr": 0.0001,
+            },
+        }
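For reference, a minimal sketch of how the new `SFTLoRADefaultConfig` defaults could be exercised. The import path mirrors the file touched by this patch; the override values and the printed output are illustrative assumptions, not part of the change itself.

```python
# Sketch only: exercises the defaults added in config.py above.
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig

# Every field falls back to the default declared on the model
defaults = SFTLoRADefaultConfig()
assert defaults.n_epochs == 50 and defaults.lora_adapter_dim == 8

# Individual hyperparameters can be overridden per job (example values)
custom = SFTLoRADefaultConfig(n_epochs=3, lr=2e-4, lora_adapter_dim=32)
print(custom.model_dump(exclude_none=True))

# sample_config() returns the nested dict shape used for documentation/tests
print(SFTLoRADefaultConfig.sample_config()["optimizer_config"]["lr"])
```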
diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py
index c494afc03..2468e1d44 100644
--- a/llama_stack/providers/remote/post_training/nvidia/post_training.py
+++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py
@@ -20,9 +20,8 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.remote.post_training.nvidia.config import (
-    NvidiaPostTrainingConfig,
-)
+from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
+from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 from .models import _MODEL_ENTRIES
@@ -106,6 +105,11 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
     ) -> ListNvidiaPostTrainingJobs:
         """Get all customization jobs.
 
         Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
+
+        Returns a ListNvidiaPostTrainingJobs object with the following fields:
+            - data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
+
+        ToDo: Support schema input for filtering.
         """
         params = {"page": page, "page_size": page_size, "sort": sort}
 
@@ -137,6 +141,22 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
     async def get_training_job_status(self, job_uuid: str) -> Optional[NvidiaPostTrainingJob]:
         """Get the status of a customization job.
 
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
+
+        Returns a NvidiaPostTrainingJob object with the following fields:
+            - job_uuid: str - Unique identifier for the job
+            - status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
+            - created_at: datetime - The time when the job was created
+            - updated_at: datetime - The last time the job status was updated
+
+        Additional fields that may be included:
+            - steps_completed: Optional[int] - Number of training steps completed
+            - epochs_completed: Optional[int] - Number of epochs completed
+            - percentage_done: Optional[float] - Percentage of training completed (0-100)
+            - best_epoch: Optional[int] - The epoch with the best performance
+            - train_loss: Optional[float] - Training loss of the best checkpoint
+            - val_loss: Optional[float] - Validation loss of the best checkpoint
+            - metrics: Optional[Dict] - Additional training metrics
+            - status_logs: Optional[List] - Detailed logs of status changes
         """
         response = await self._make_request(
             "GET",
@@ -156,23 +176,20 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
         )
 
     async def cancel_training_job(self, job_uuid: str) -> None:
-        """Cancels a customization job."""
         await self._make_request(
             method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
         )
 
     async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        """Get artifacts for a specific training job."""
         raise NotImplementedError("Job artifacts are not implemented yet")
 
     async def get_post_training_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        """Get all post-training artifacts."""
         raise NotImplementedError("Job artifacts are not implemented yet")
 
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
+        training_config: Dict[str, Any],
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
         model: str,
@@ -195,15 +212,22 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
             training_config: TrainingConfig - Configuration for training
             model: str - Model identifier
             algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
-            checkpoint_dir: Optional[str] - Directory containing model checkpoints
-            job_uuid: str - Unique identifier for the job
-            hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
-            logger_config: Dict[str, Any] - Configuration for logging
+            checkpoint_dir: Optional[str] - Directory containing model checkpoints, currently ignored
+            job_uuid: str - Unique identifier for the job, currently ignored
+            hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, currently ignored
+            logger_config: Dict[str, Any] - Configuration for logging, currently ignored
 
         Environment Variables:
-            - NVIDIA_PROJECT_ID: ID of the project
-            - NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
-            - NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
+            - NVIDIA_API_KEY: str - API key for the NVIDIA API
+              Default: None
+            - NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
+              Default: "default"
+            - NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
+              Default: "http://nemo.test"
+            - NVIDIA_PROJECT_ID: str - ID of the project
+              Default: "test-project"
+            - NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
+              Default: "test-example-model@v1"
 
         Supported models:
             - meta/llama-3.1-8b-instruct
@@ -213,57 +237,100 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper):
 
         Supported Parameters:
             - TrainingConfig:
-                - n_epochs
-                - data_config
-                - optimizer_config
-                - dtype
-                - efficiency_config
-                - max_steps_per_epoch
+                - n_epochs: int - Number of epochs to train
+                  Default: 50
+                - data_config: DataConfig - Configuration for the dataset
+                - optimizer_config: OptimizerConfig - Configuration for the optimizer
+                - dtype: str - Data type for training
+                  Not supported (users are informed via warnings)
+                - efficiency_config: EfficiencyConfig - Configuration for efficiency
+                  Not supported
+                - max_steps_per_epoch: int - Maximum number of steps per epoch
+                  Default: 1000
+                ## NeMo customizer specific parameters
+                - log_every_n_steps: int - Log every n steps
+                  Default: None
+                - val_check_interval: float - Validation check interval
+                  Default: 0.25
+                - sequence_packing_enabled: bool - Whether sequence packing is enabled
+                  Default: False
+                ## NeMo customizer specific SFT parameters
+                - hidden_dropout: float - Hidden dropout (0.0-1.0)
+                  Default: None
+                - attention_dropout: float - Attention dropout (0.0-1.0)
+                  Default: None
+                - ffn_dropout: float - FFN dropout (0.0-1.0)
+                  Default: None
+
             - DataConfig:
-                - dataset_id
-                - batch_size
+                - dataset_id: str - Dataset ID
+                - batch_size: int - Batch size
+                  Default: 8
+
             - OptimizerConfig:
-                - lr
-                - weight_decay
+                - lr: float - Learning rate
+                  Default: 0.0001
+                ## NeMo customizer specific parameter
+                - weight_decay: float - Weight decay
+                  Default: 0.01
+
             - LoRA config:
-                - adapter_dim
-                - adapter_dropout
+                ## NeMo customizer specific LoRA parameters
+                - adapter_dim: int - Adapter dimension
+                  Default: 8 (supports powers of 2)
+                - adapter_dropout: float - Adapter dropout (0.0-1.0)
+                  Default: None
+                - alpha: int - Scaling factor for the LoRA update
+                  Default: 16
 
         Note:
-            - checkpoint_dir, hyperparam_search_config, logger_config are not supported atm, will be ignored and users are informed via warnings.
-            - Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported atm, will be ignored and users are informed via warnings.
+            - checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
+            - Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
 
         User is informed about unsupported parameters via warnings.
""" - # map model to nvidia model name + # Map model to nvidia model name + # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models nvidia_model = self.get_provider_model_id(model) - # Check the extra parameters - print(hyperparam_search_config, extra_json, params, headers, kwargs) + # Check for unsupported method parameters + unsupported_method_params = [] + if checkpoint_dir: + unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}") + if hyperparam_search_config: + unsupported_method_params.append("hyperparam_search_config") + if logger_config: + unsupported_method_params.append("logger_config") - # Check for unsupported parameters - if checkpoint_dir or hyperparam_search_config or logger_config: - warnings.warn( - "Parameters: {} not supported atm, will be ignored".format( - checkpoint_dir, - ) - ) + if unsupported_method_params: + warnings.warn(f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored") - def warn_unsupported_params(config_dict: Dict[str, Any], supported_keys: List[str], config_name: str) -> None: - """Helper function to warn about unsupported parameters in a config dictionary.""" - unsupported_params = [k for k in config_dict.keys() if k not in supported_keys] - if unsupported_params: - warnings.warn(f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.") + # Define all supported parameters + supported_params = { + "training_config": { + "n_epochs", + "data_config", + "optimizer_config", + "log_every_n_steps", + "val_check_interval", + "sequence_packing_enabled", + "hidden_dropout", + "attention_dropout", + "ffn_dropout", + }, + "data_config": {"dataset_id", "batch_size"}, + "optimizer_config": {"lr", "weight_decay"}, + "lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"}, + } - # Check for unsupported parameters - warn_unsupported_params(training_config, ["n_epochs", "data_config", "optimizer_config"], "TrainingConfig") - warn_unsupported_params(training_config["data_config"], ["dataset_id", "batch_size"], "DataConfig") - warn_unsupported_params(training_config["optimizer_config"], ["lr"], "OptimizerConfig") + # Validate all parameters at once + warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig") + warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig") + warn_unsupported_params( + training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig" + ) output_model = self.config.output_model_dir - if output_model == "default": - warnings.warn("output_model_dir set via default value, will be ignored") - # Prepare base job configuration job_config = { "config": nvidia_model, @@ -274,9 +341,19 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper): "hyperparameters": { "training_type": "sft", "finetuning_type": "lora", - "epochs": training_config.get("n_epochs", 1), - "batch_size": training_config["data_config"].get("batch_size", 8), - "learning_rate": training_config["optimizer_config"].get("lr", 0.0001), + **{ + k: v + for k, v in { + "epochs": training_config.get("n_epochs"), + "batch_size": training_config["data_config"].get("batch_size"), + "learning_rate": training_config["optimizer_config"].get("lr"), + "weight_decay": training_config["optimizer_config"].get("weight_decay"), + "log_every_n_steps": training_config.get("log_every_n_steps"), + "val_check_interval": training_config.get("val_check_interval"), 
+ "sequence_packing_enabled": training_config.get("sequence_packing_enabled"), + }.items() + if v is not None + }, }, "project": self.config.project_id, # TODO: ignored ownership, add it later @@ -284,18 +361,37 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper): "output_model": output_model, } + # Handle SFT-specific optional parameters + job_config["hyperparameters"]["sft"] = { + k: v + for k, v in { + "ffn_dropout": training_config.get("ffn_dropout"), + "hidden_dropout": training_config.get("hidden_dropout"), + "attention_dropout": training_config.get("attention_dropout"), + }.items() + if v is not None + } + + # Remove the sft dictionary if it's empty + if not job_config["hyperparameters"]["sft"]: + job_config["hyperparameters"].pop("sft") + # Handle LoRA-specific configuration if algorithm_config: if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA": - # Extract LoRA-specific parameters - lora_config = {k: v for k, v in algorithm_config.items() if k != "type"} + warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config") job_config["hyperparameters"]["lora"] = { - "adapter_dim": lora_config.get("adapter_dim", 8), - "adapter_dropout": lora_config.get("adapter_dropout", 1), + k: v + for k, v in { + "adapter_dim": algorithm_config.get("adapter_dim"), + "alpha": algorithm_config.get("alpha"), + "adapter_dropout": algorithm_config.get("adapter_dropout"), + }.items() + if v is not None } - warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config") else: raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}") + # Create the customization job response = await self._make_request( method="POST", @@ -305,12 +401,12 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper): ) job_uuid = response["id"] - status = STATUS_MAPPING.get(response["status"].lower(), "unknown") - created_at = datetime.fromisoformat(response["created_at"]) - updated_at = datetime.fromisoformat(response["updated_at"]) + response.pop("status") + created_at = datetime.fromisoformat(response.pop("created_at")) + updated_at = datetime.fromisoformat(response.pop("updated_at")) return NvidiaPostTrainingJob( - job_uuid=job_uuid, status=JobStatus(status), created_at=created_at, updated_at=updated_at, **response + job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response ) async def preference_optimize( @@ -326,5 +422,4 @@ class NvidiaPostTrainingAdapter(PostTraining, ModelRegistryHelper): raise NotImplementedError("Preference optimization is not implemented yet") async def get_training_job_container_logs(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: - """Get the container logs of a customization job.""" raise NotImplementedError("Job logs are not implemented yet") diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index a8d2a9cbc..383df9c2c 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -4,20 +4,54 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, IAny, nc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py
index a8d2a9cbc..383df9c2c 100644
--- a/llama_stack/providers/remote/post_training/nvidia/utils.py
+++ b/llama_stack/providers/remote/post_training/nvidia/utils.py
@@ -4,20 +4,54 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
 import logging
-from typing import Tuple
+import warnings
+from typing import Any, Dict, Set, Tuple
+
+from pydantic import BaseModel
+
+from llama_stack.apis.post_training import TrainingConfig
+from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
 
 from .config import NvidiaPostTrainingConfig
 
 logger = logging.getLogger(__name__)
 
 
+def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
+    keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
+    unsupported_params = [k for k in keys if k not in supported_keys]
+    if unsupported_params:
+        warnings.warn(f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.")
+
+
+def validate_training_params(
+    training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
+) -> None:
+    """
+    Validates training parameters against supported keys.
+
+    Args:
+        training_config: Dictionary containing training configuration parameters
+        supported_keys: Set of supported parameter keys
+        config_name: Name of the configuration for warning messages
+    """
+    sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
+    training_config_fields = set(TrainingConfig.__annotations__.keys())
+
+    # Flag parameters that are:
+    #   - not a field of either config, or
+    #   - a TrainingConfig field that SFTLoRADefaultConfig does not support
+    unsupported_params = []
+    for key in training_config:
+        if not isinstance(key, str) or key in supported_keys:
+            continue
+        if key in training_config_fields - sft_lora_fields or key not in training_config_fields | sft_lora_fields:
+            unsupported_params.append(key)
+
+    if unsupported_params:
+        warnings.warn(f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.")
+
+
 # ToDo: implement post health checks for customizer are enabled
 async def _get_health(url: str) -> Tuple[bool, bool]: ...
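As a quick illustration of the helper's behavior, the hedged sketch below calls `warn_unsupported_params` with a plain dict and captures the warning it raises. The config keys and values are made-up examples chosen only to trigger the warning:

```python
# Illustrative only: shows the warning emitted for keys outside `supported_keys`.
import warnings

from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params

config = {"lr": 1e-4, "weight_decay": 0.01, "optimizer_type": "adam"}  # example values

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_unsupported_params(config, supported_keys={"lr", "weight_decay"}, config_name="OptimizerConfig")

# One warning names the ignored key, e.g. "Parameters: ['optimizer_type'] in `OptimizerConfig` ..."
assert any("optimizer_type" in str(w.message) for w in caught)
```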
diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py
new file mode 100644
index 000000000..db95a03c0
--- /dev/null
+++ b/tests/unit/providers/nvidia/test_parameters.py
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import unittest
+import warnings
+from unittest.mock import patch
+
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigEfficiencyConfig,
+    TrainingConfigOptimizerConfig,
+)
+
+
+class TestNvidiaParameters(unittest.TestCase):
+    def setUp(self):
+        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
+        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+        os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
+
+        self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
+        _ = self.llama_stack_client.initialize()
+
+        self.make_request_patcher = patch(
+            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
+        )
+        self.mock_make_request = self.make_request_patcher.start()
+        self.mock_make_request.return_value = {
+            "id": "job-123",
+            "status": "created",
+            "created_at": "2025-03-04T13:07:47.543605",
+            "updated_at": "2025-03-04T13:07:47.543605",
+        }
+
+    def tearDown(self):
+        self.make_request_patcher.stop()
+
+    def _assert_request_params(self, expected_json):
+        """Helper method to verify parameters in the request JSON."""
+        call_args = self.mock_make_request.call_args
+        actual_json = call_args[1]["json"]
+
+        for key, value in expected_json.items():
+            if isinstance(value, dict):
+                for nested_key, nested_value in value.items():
+                    assert actual_json[key][nested_key] == nested_value
+            else:
+                assert actual_json[key] == value
+
+    def test_optional_parameter_passed(self):
+        """Test scenario 1: When an optional parameter is passed, its value is correctly set."""
+        custom_adapter_dim = 32  # Different from default of 8
+        algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
+            adapter_dim=custom_adapter_dim,  # Custom value
+            adapter_dropout=0.2,  # Custom value
+        )
+
+        data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
+        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
+        training_config = TrainingConfig(
+            n_epochs=3,
+            data_config=data_config,
+            optimizer_config=optimizer_config,
+        )
+
+        self.llama_stack_client.post_training.supervised_fine_tune(
+            job_uuid="test-job",
+            model="meta-llama/Llama-3.1-8B-Instruct",
+            checkpoint_dir="",
+            algorithm_config=algorithm_config,
+            training_config=training_config,
+            logger_config={},
+            hyperparam_search_config={},
+        )
+
+        self._assert_request_params(
+            {
+                "hyperparameters": {
+                    "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2},
+                    "epochs": 3,
+                    "learning_rate": 0.0002,
+                    "batch_size": 16,
+                }
+            }
+        )
+
+    def test_required_parameter_passed(self):
+        """Test scenario 2: When required parameters are passed."""
+        required_model = "meta-llama/Llama-3.1-8B-Instruct"
+        required_dataset_id = "required-dataset"
+        required_job_uuid = "required-job"
+
+        algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=8)
+
+        data_config = TrainingConfigDataConfig(
+            dataset_id=required_dataset_id,  # Required parameter
+            batch_size=8,
+        )
+
+        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
+
+        training_config = TrainingConfig(
+            n_epochs=1,
+            data_config=data_config,
+            optimizer_config=optimizer_config,
+        )
+
+        self.llama_stack_client.post_training.supervised_fine_tune(
+            job_uuid=required_job_uuid,  # Required parameter
+            model=required_model,  # Required parameter
checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=training_config, + logger_config={}, + hyperparam_search_config={}, + ) + + self.mock_make_request.assert_called_once() + call_args = self.mock_make_request.call_args + + assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct" + assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id + + def test_unsupported_parameters_warning(self): + """Test that warnings are raised for unsupported parameters.""" + # Create a training config with unsupported parameters + data_config = TrainingConfigDataConfig( + dataset_id="test-dataset", + batch_size=8, + # Unsupported parameters + shuffle=True, + data_format="instruct", + validation_dataset_id="val-dataset", + ) + + optimizer_config = TrainingConfigOptimizerConfig( + lr=0.0001, + weight_decay=0.01, + # Unsupported parameters + optimizer_type="adam", + num_warmup_steps=100, + ) + + efficiency_config = TrainingConfigEfficiencyConfig( + enable_activation_checkpointing=True # Unsupported parameter + ) + + training_config = TrainingConfig( + n_epochs=1, + data_config=data_config, + optimizer_config=optimizer_config, + # Unsupported parameters + efficiency_config=efficiency_config, + max_steps_per_epoch=1000, + gradient_accumulation_steps=4, + max_validation_steps=100, + dtype="bf16", + ) + + # Capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + self.llama_stack_client.post_training.supervised_fine_tune( + job_uuid="test-job", + model="meta-llama/Llama-3.1-8B-Instruct", + checkpoint_dir="test-dir", # Unsupported parameter + algorithm_config=LoraFinetuningConfig(type="LoRA"), + training_config=training_config, + logger_config={"test": "value"}, # Unsupported parameter + hyperparam_search_config={"test": "value"}, # Unsupported parameter + ) + assert len(w) >= 4 + warning_texts = [str(warning.message) for warning in w] + + fields = [ + "checkpoint_dir", + "hyperparam_search_config", + "logger_config", + "TrainingConfig", + "DataConfig", + "OptimizerConfig", + "max_steps_per_epoch", + "gradient_accumulation_steps", + "max_validation_steps", + "dtype", + ] + for field in fields: + assert any(field in text for text in warning_texts) + + +if __name__ == "__main__": + unittest.main()