Nathan Weinberg 2025-08-14 13:56:45 -04:00 committed by GitHub
commit 8311ee7457
3 changed files with 26 additions and 16 deletions

llama_stack/apis/common/errors.py

@@ -72,3 +72,11 @@ class ModelTypeError(TypeError):
             f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
         )
         super().__init__(message)
+
+
+class MissingTrainingConfigError(ValueError):
+    """Raised when Llama Stack is missing configuration for training."""
+
+    def __init__(self, config_name: str) -> None:
+        message = f"'{config_name}' is required for training"
+        super().__init__(message)
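Because the new class subclasses ValueError, pre-existing callers that catch ValueError keep working; only the type and message become more specific. A minimal standalone sketch of that behavior (the class body mirrors the diff above; the try/except is illustrative, not part of the commit):

```python
class MissingTrainingConfigError(ValueError):
    """Raised when Llama Stack is missing configuration for training."""

    def __init__(self, config_name: str) -> None:
        message = f"'{config_name}' is required for training"
        super().__init__(message)


try:
    raise MissingTrainingConfigError("DataConfig")
except ValueError as err:  # still caught by existing ValueError handlers
    print(err)  # 'DataConfig' is required for training
```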

llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py

@@ -20,6 +20,7 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
 
+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -224,8 +225,8 @@ class HFFinetuningSingleDevice:
             tuple: (train_dataset, eval_dataset, tokenizer)
         """
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
@@ -300,8 +301,8 @@ class HFFinetuningSingleDevice:
             logger.info(f"Using custom learning rate: {lr}")
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config
 
         # Calculate steps and get save strategy
@@ -392,8 +393,8 @@ class HFFinetuningSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
 
         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
 
         # Setup training arguments
@@ -475,8 +476,8 @@ class HFFinetuningSingleDevice:
         )
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Train in a separate process
         logger.info("Starting training in separate process")

llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py

@@ -17,6 +17,7 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer
 
+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -204,8 +205,8 @@ class HFDPOAlignmentSingleDevice:
     ) -> tuple[Dataset, Dataset, AutoTokenizer]:
         """Load and prepare the dataset for DPO training."""
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for DPO training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
@@ -266,8 +267,8 @@ class HFDPOAlignmentSingleDevice:
             logger.info(f"Using custom learning rate: {lr}")
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config
 
         # Calculate steps and get save strategy
@@ -356,8 +357,8 @@ class HFDPOAlignmentSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
 
         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
 
         # Setup training arguments
@@ -441,8 +442,8 @@ class HFDPOAlignmentSingleDevice:
         }
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Train in a separate process
         logger.info("Starting DPO training in separate process")