Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 06:00:48 +00:00)

Commit 8311ee7457: Merge c36628fefd into 61582f327c

3 changed files with 26 additions and 16 deletions
@@ -72,3 +72,11 @@ class ModelTypeError(TypeError):
             f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
         )
         super().__init__(message)
+
+
+class MissingTrainingConfigError(ValueError):
+    """Raised when Llama Stack is missing configuration for training."""
+
+    def __init__(self, config_name: str) -> None:
+        message = f"'{config_name}' is required for training"
+        super().__init__(message)
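For reference, a minimal standalone sketch of the new error type; only the class body comes from the hunk above, the try/except harness is illustrative and assumes nothing beyond the Python standard library:

# The new error type, as added in the hunk above.
class MissingTrainingConfigError(ValueError):
    """Raised when Llama Stack is missing configuration for training."""

    def __init__(self, config_name: str) -> None:
        message = f"'{config_name}' is required for training"
        super().__init__(message)

# Because it subclasses ValueError, any existing caller that caught
# ValueError from the training recipes still catches the new error:
try:
    raise MissingTrainingConfigError("DataConfig")
except ValueError as exc:
    print(exc)  # 'DataConfig' is required for training

Subclassing ValueError rather than Exception is what makes the swap below (ValueError to MissingTrainingConfigError) backward compatible for downstream handlers.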
@@ -20,6 +20,7 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer

+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -224,8 +225,8 @@ class HFFinetuningSingleDevice:
             tuple: (train_dataset, eval_dataset, tokenizer)
         """
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")

         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
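Note that each guard in this commit is also tightened from a truthiness test (not config.data_config) to an explicit identity check (config.data_config is None). The two differ for a config object that is present but falsy; a minimal sketch of the distinction, using a hypothetical EmptyDataConfig stand-in:

# Hypothetical config object that exists but evaluates as falsy,
# e.g. because it defines __bool__ (or __len__) returning False/0.
class EmptyDataConfig:
    def __bool__(self) -> bool:
        return False

cfg = EmptyDataConfig()
print(not cfg)      # True  -> the old truthiness guard would reject cfg
print(cfg is None)  # False -> the new identity guard lets cfg through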
@@ -300,8 +301,8 @@ class HFFinetuningSingleDevice:
             logger.info(f"Using custom learning rate: {lr}")

         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config

         # Calculate steps and get save strategy
@@ -392,8 +393,8 @@ class HFFinetuningSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)

         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size

         # Setup training arguments
@@ -475,8 +476,8 @@ class HFFinetuningSingleDevice:
         )

         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")

         # Train in a separate process
         logger.info("Starting training in separate process")
@@ -17,6 +17,7 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer

+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -204,8 +205,8 @@ class HFDPOAlignmentSingleDevice:
     ) -> tuple[Dataset, Dataset, AutoTokenizer]:
         """Load and prepare the dataset for DPO training."""
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for DPO training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")

         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
@@ -266,8 +267,8 @@ class HFDPOAlignmentSingleDevice:
             logger.info(f"Using custom learning rate: {lr}")

         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config

         # Calculate steps and get save strategy
@@ -356,8 +357,8 @@ class HFDPOAlignmentSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)

         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size

         # Setup training arguments
@@ -441,8 +442,8 @@ class HFDPOAlignmentSingleDevice:
         }

         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")

         # Train in a separate process
         logger.info("Starting DPO training in separate process")