Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)

Merge c36628fefd into 61582f327c
Commit 8311ee7457
3 changed files with 26 additions and 16 deletions

This change introduces a dedicated MissingTrainingConfigError in llama_stack.apis.common.errors and replaces the bare ValueError previously raised by the HF SFT and DPO single-device trainers when a DataConfig is missing; the truthiness checks on data_config are also tightened to explicit "is None" comparisons.
@@ -72,3 +72,11 @@ class ModelTypeError(TypeError):
             f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
         )
         super().__init__(message)
+
+
+class MissingTrainingConfigError(ValueError):
+    """raise when Llama Stack is missing configuration for training"""
+
+    def __init__(self, config_name: str) -> None:
+        message = f"'{config_name}' is required for training"
+        super().__init__(message)
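
A minimal sketch of how the new exception behaves; the class body is copied from the hunk above, while the try/except driver is purely illustrative. Because MissingTrainingConfigError subclasses ValueError, any existing handler that caught the old bare ValueError keeps working:

# Illustrative driver; the class body is copied verbatim from the hunk above.
class MissingTrainingConfigError(ValueError):
    """raise when Llama Stack is missing configuration for training"""

    def __init__(self, config_name: str) -> None:
        message = f"'{config_name}' is required for training"
        super().__init__(message)


try:
    raise MissingTrainingConfigError("DataConfig")
except ValueError as err:  # subclassing ValueError keeps old handlers working
    print(err)  # -> 'DataConfig' is required for training
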
@@ -20,6 +20,7 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
 
+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -224,8 +225,8 @@ class HFFinetuningSingleDevice:
             tuple: (train_dataset, eval_dataset, tokenizer)
         """
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
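
Note that the guard itself also changes, from "if not config.data_config:" to "if config.data_config is None:". The two are not equivalent: a config value that is present but falsy would trip the old truthiness test, while the identity test correctly lets it through. A small illustration, using a hypothetical DataConfig stand-in:

# Hypothetical stand-in: a config type that happens to be falsy when empty.
class DataConfig(dict):
    pass

cfg = DataConfig()      # present, but empty and therefore falsy

print(not cfg)          # True  -> the old check would wrongly reject it
print(cfg is None)      # False -> the new check accepts it
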
@@ -300,8 +301,8 @@ class HFFinetuningSingleDevice:
         logger.info(f"Using custom learning rate: {lr}")
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config
 
         # Calculate steps and get save strategy
@@ -392,8 +393,8 @@ class HFFinetuningSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
 
         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
 
         # Setup training arguments
@@ -475,8 +476,8 @@ class HFFinetuningSingleDevice:
         )
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Train in a separate process
         logger.info("Starting training in separate process")
@@ -17,6 +17,7 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer
 
+from llama_stack.apis.common.errors import MissingTrainingConfigError
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -204,8 +205,8 @@ class HFDPOAlignmentSingleDevice:
     ) -> tuple[Dataset, Dataset, AutoTokenizer]:
         """Load and prepare the dataset for DPO training."""
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for DPO training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
@@ -266,8 +267,8 @@ class HFDPOAlignmentSingleDevice:
         logger.info(f"Using custom learning rate: {lr}")
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         data_config = config.data_config
 
         # Calculate steps and get save strategy
@@ -356,8 +357,8 @@ class HFDPOAlignmentSingleDevice:
         train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
 
         # Calculate steps per epoch
-        if not config_obj.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config_obj.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
         steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
 
         # Setup training arguments
@@ -441,8 +442,8 @@ class HFDPOAlignmentSingleDevice:
         }
 
         # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
+        if config.data_config is None:
+            raise MissingTrainingConfigError("DataConfig")
 
         # Train in a separate process
         logger.info("Starting DPO training in separate process")
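
Taken together, callers of either trainer can now tell a missing training config apart from other ValueErrors. A sketch of the caller-side benefit, assuming a hypothetical run_training entry point and trainer API shape; only the exception type and import path come from the diff:

# Hypothetical driver; trainer.train's signature is an assumption.
from llama_stack.apis.common.errors import MissingTrainingConfigError


async def run_training(trainer, model, config) -> None:
    try:
        await trainer.train(model=model, config=config)  # hypothetical signature
    except MissingTrainingConfigError as err:
        # Specific, actionable failure: a required config section was omitted.
        print(f"training config error: {err}")
    except ValueError as err:
        # All other value errors, previously indistinguishable from the above.
        print(f"unexpected value error: {err}")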