mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
feat: add huggingface post_training impl
adds an inline HF SFTTrainer provider. Alongside touchtune -- this is a super popular option for running training jobs. The config allows a user to specify some key fields such as a model, chat_template, device, etc the provider comes with one recipe `finetune_single_device` which works both with and without LoRA. any model that is a valid HF identifier can be given and the model will be pulled. this has been tested so far with CPU and MPS device types, but should be compatible with CUDA out of the box The provider processes the given dataset into the proper format, established the various steps per epoch, steps per save, steps per eval, sets a sane SFTConfig, and runs n_epochs of training if checkpoint_dir is none, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
65cf076f13
commit
6c3a40e3d2
5 changed files with 788 additions and 0 deletions
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
# post_training api and the huggingface provider is still experimental and under heavy development
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import HuggingFacePostTrainingImpl
|
||||
|
||||
impl = HuggingFacePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HuggingFacePostTrainingConfig(BaseModel):
|
||||
# Device to run training on (cuda, cpu, mps)
|
||||
device: str = "cuda"
|
||||
|
||||
# Distributed training backend if using multiple devices
|
||||
# fsdp: Fully Sharded Data Parallel
|
||||
# deepspeed: DeepSpeed ZeRO optimization
|
||||
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
|
||||
|
||||
# Format for saving model checkpoints
|
||||
# full_state: Save complete model state
|
||||
# huggingface: Save in HuggingFace format (recommended for compatibility)
|
||||
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
|
||||
|
||||
# Template for formatting chat inputs and outputs
|
||||
# Used to structure the conversation format for training
|
||||
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
|
||||
|
||||
# Model-specific configuration parameters
|
||||
# trust_remote_code: Allow execution of custom model code
|
||||
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
|
||||
model_specific_config: dict = {
|
||||
"trust_remote_code": True,
|
||||
"attn_implementation": "sdpa",
|
||||
}
|
||||
|
||||
# Maximum sequence length for training
|
||||
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
|
||||
# Longer sequences may cause memory issues on MPS devices
|
||||
max_seq_length: int = 2048
|
||||
|
||||
# Enable gradient checkpointing to reduce memory usage
|
||||
# Trades computation for memory by recomputing activations
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
# Maximum number of checkpoints to keep
|
||||
# Older checkpoints are deleted when this limit is reached
|
||||
save_total_limit: int = 3
|
||||
|
||||
# Number of training steps between logging updates
|
||||
logging_steps: int = 10
|
||||
|
||||
# Ratio of training steps used for learning rate warmup
|
||||
# Helps stabilize early training
|
||||
warmup_ratio: float = 0.1
|
||||
|
||||
# L2 regularization coefficient
|
||||
# Helps prevent overfitting
|
||||
weight_decay: float = 0.01
|
||||
|
||||
# Number of worker processes for data loading
|
||||
# Higher values can improve data loading speed but increase memory usage
|
||||
dataloader_num_workers: int = 4
|
||||
|
||||
# Whether to pin memory in data loader
|
||||
# Can improve data transfer speed to GPU but uses more memory
|
||||
dataloader_pin_memory: bool = True
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class HuggingFacePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
recipe = HFFinetuningSingleDevice(
|
||||
job_uuid=job_uuid,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
)
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=model,
|
||||
output_dir=checkpoint_dir,
|
||||
job_uuid=job_uuid,
|
||||
lora_config=algorithm_config,
|
||||
config=training_config,
|
||||
provider_config=self.config,
|
||||
)
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
if checkpoints:
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("HF finetuning completed")
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
raise NotImplementedError("DPO alignment is not implemented yet")
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
|
@ -0,0 +1,502 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
|
||||
# Set tokenizer parallelism environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_gb(to_convert: int) -> str:
|
||||
"""Converts memory stats to GB and formats to 2 decimal places.
|
||||
Args:
|
||||
to_convert: Memory value in bytes
|
||||
Returns:
|
||||
str: Memory value in GB formatted to 2 decimal places
|
||||
"""
|
||||
return f"{(to_convert / (1024**3)):.2f}"
|
||||
|
||||
|
||||
def get_memory_stats(device: torch.device) -> dict[str, Any]:
|
||||
"""Get memory statistics for the given device."""
|
||||
stats = {
|
||||
"system_memory": {
|
||||
"total": get_gb(psutil.virtual_memory().total),
|
||||
"available": get_gb(psutil.virtual_memory().available),
|
||||
"used": get_gb(psutil.virtual_memory().used),
|
||||
"percent": psutil.virtual_memory().percent,
|
||||
}
|
||||
}
|
||||
|
||||
if device.type == "cuda":
|
||||
stats["device_memory"] = {
|
||||
"allocated": get_gb(torch.cuda.memory_allocated(device)),
|
||||
"reserved": get_gb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
|
||||
}
|
||||
elif device.type == "mps":
|
||||
# MPS doesn't provide direct memory stats, but we can track system memory
|
||||
stats["device_memory"] = {
|
||||
"note": "MPS memory stats not directly available",
|
||||
"system_memory_used": get_gb(psutil.virtual_memory().used),
|
||||
}
|
||||
elif device.type == "cpu":
|
||||
# For CPU, we track process memory usage
|
||||
process = psutil.Process()
|
||||
stats["device_memory"] = {
|
||||
"process_rss": get_gb(process.memory_info().rss),
|
||||
"process_vms": get_gb(process.memory_info().vms),
|
||||
"process_percent": process.memory_percent(),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
def __init__(
|
||||
self,
|
||||
job_uuid,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
):
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.job_uuid = job_uuid
|
||||
|
||||
def validate_dataset_format(self, rows: list[dict]) -> bool:
|
||||
"""Validate that the dataset has the required fields."""
|
||||
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
|
||||
return all(field in row for row in rows for field in required_fields)
|
||||
|
||||
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row in instruct format."""
|
||||
if "chat_completion_input" in row and "expected_answer" in row:
|
||||
try:
|
||||
messages = json.loads(row["chat_completion_input"])
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
if "content" not in messages[0]:
|
||||
logger.warning(f"Message missing content: {messages[0]}")
|
||||
return None, None
|
||||
return messages[0]["content"], row["expected_answer"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row in dialog format."""
|
||||
if "dialog" in row:
|
||||
try:
|
||||
dialog = json.loads(row["dialog"])
|
||||
if not isinstance(dialog, list) or len(dialog) < 2:
|
||||
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
return None, None
|
||||
if dialog[0].get("role") != "user":
|
||||
logger.warning(f"First message must be from user: {dialog[0]}")
|
||||
return None, None
|
||||
if not any(msg.get("role") == "assistant" for msg in dialog):
|
||||
logger.warning("Dialog must have at least one assistant message")
|
||||
return None, None
|
||||
|
||||
# Convert to human/gpt format
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
conversations = []
|
||||
for msg in dialog:
|
||||
if "role" not in msg or "content" not in msg:
|
||||
logger.warning(f"Message missing role or content: {msg}")
|
||||
continue
|
||||
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
|
||||
|
||||
# Format as a single conversation
|
||||
return conversations[0]["value"], conversations[1]["value"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
|
||||
"""Process a row using fallback formats."""
|
||||
if "input" in row and "output" in row:
|
||||
return row["input"], row["output"]
|
||||
elif "prompt" in row and "completion" in row:
|
||||
return row["prompt"], row["completion"]
|
||||
elif "question" in row and "answer" in row:
|
||||
return row["question"], row["answer"]
|
||||
return None, None
|
||||
|
||||
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
|
||||
"""Format input and output text based on model requirements."""
|
||||
if hasattr(provider_config, "chat_template"):
|
||||
return provider_config.chat_template.format(input=input_text, output=output_text)
|
||||
return f"{input_text}\n{output_text}"
|
||||
|
||||
def _create_dataset(
|
||||
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Create and preprocess the dataset."""
|
||||
formatted_rows = []
|
||||
for row in rows:
|
||||
input_text = None
|
||||
output_text = None
|
||||
|
||||
# Process based on format
|
||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
if config.data_config.data_format.value == "instruct":
|
||||
input_text, output_text = self._process_instruct_format(row)
|
||||
elif config.data_config.data_format.value == "dialog":
|
||||
input_text, output_text = self._process_dialog_format(row)
|
||||
else:
|
||||
input_text, output_text = self._process_fallback_format(row)
|
||||
|
||||
if input_text and output_text:
|
||||
formatted_text = self._format_text(input_text, output_text, provider_config)
|
||||
formatted_rows.append({"text": formatted_text})
|
||||
|
||||
if not formatted_rows:
|
||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
raise ValueError(
|
||||
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
|
||||
)
|
||||
|
||||
return Dataset.from_list(formatted_rows)
|
||||
|
||||
def _preprocess_dataset(
|
||||
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
|
||||
) -> Dataset:
|
||||
"""Preprocess the dataset with tokenizer."""
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=provider_config.max_seq_length,
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
return ds.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=ds.column_names,
|
||||
)
|
||||
|
||||
async def load_dataset(
|
||||
self,
|
||||
model: str,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[Dataset, Dataset, AutoTokenizer]:
|
||||
"""Load and preprocess the dataset for training.
|
||||
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
config: Training configuration containing dataset settings
|
||||
provider_config: Provider-specific configuration
|
||||
|
||||
Returns:
|
||||
tuple containing:
|
||||
- Training dataset
|
||||
- Evaluation dataset
|
||||
- Tokenizer
|
||||
|
||||
Raises:
|
||||
ValueError: If dataset is missing required fields
|
||||
RuntimeError: If tokenizer initialization fails
|
||||
"""
|
||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
rows = await self._setup_data(config.data_config.dataset_id)
|
||||
|
||||
# Validate that the dataset has the required fields for training
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
|
||||
# Initialize tokenizer with model-specific config
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
# Set up tokenizer defaults
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.truncation_side = "right"
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
# Split dataset into train and validation
|
||||
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
|
||||
return train_val_split["train"], train_val_split["test"], tokenizer
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model: str,
|
||||
device: torch.device,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> AutoModelForCausalLM:
|
||||
"""Load and initialize the model for training.
|
||||
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
device: The device to load the model onto
|
||||
provider_config: Provider-specific configuration
|
||||
|
||||
Returns:
|
||||
The loaded and initialized model
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
logger.info("Loading the base model")
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
|
||||
model_obj = AutoModelForCausalLM.from_pretrained(
|
||||
model,
|
||||
torch_dtype="auto",
|
||||
quantization_config=None,
|
||||
config=model_config,
|
||||
**provider_config.model_specific_config,
|
||||
)
|
||||
if model_obj.device != device:
|
||||
model_obj = model_obj.to(device)
|
||||
logger.info(f"Model loaded and moved to device: {model_obj.device}")
|
||||
return model_obj
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
||||
|
||||
async def train(
|
||||
self,
|
||||
model: str,
|
||||
output_dir: str | None,
|
||||
job_uuid: str,
|
||||
lora_config: LoraFinetuningConfig,
|
||||
config: TrainingConfig,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
|
||||
"""Train a model using HuggingFace's SFTTrainer"""
|
||||
try:
|
||||
device = torch.device(provider_config.device)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
|
||||
|
||||
# Detect device type and validate
|
||||
if device.type == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
|
||||
)
|
||||
# map unqualified 'cuda' to current device
|
||||
if device.index is None:
|
||||
device = torch.device(device.type, torch.cuda.current_device())
|
||||
elif device.type == "mps":
|
||||
if not torch.backends.mps.is_available():
|
||||
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
|
||||
elif device.type == "hpu":
|
||||
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
|
||||
|
||||
logger.info(f"Using device '{device}'")
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
output_dir_path = Path(output_dir)
|
||||
|
||||
# Track memory stats throughout training
|
||||
memory_stats = {
|
||||
"initial": get_memory_stats(device),
|
||||
"after_model_load": None,
|
||||
"after_training": None,
|
||||
"final": None,
|
||||
}
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config, provider_config)
|
||||
|
||||
# Load model with model-specific config
|
||||
model_obj = self.load_model(model, device, provider_config)
|
||||
memory_stats["after_model_load"] = get_memory_stats(device)
|
||||
|
||||
# Configure LoRA
|
||||
peft_config = None
|
||||
if lora_config:
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=lora_config.alpha,
|
||||
lora_dropout=0.1,
|
||||
r=lora_config.rank,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=lora_config.lora_attn_modules,
|
||||
)
|
||||
|
||||
# Setup training arguments
|
||||
lr = 2e-5
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
|
||||
# Calculate steps per epoch and appropriate intervals
|
||||
steps_per_epoch = len(train_dataset) // config.data_config.batch_size
|
||||
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
|
||||
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
|
||||
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
|
||||
|
||||
logger.info(f"Dataset size: {len(train_dataset)} examples")
|
||||
logger.info(f"Batch size: {config.data_config.batch_size}")
|
||||
logger.info(f"Steps per epoch: {steps_per_epoch}")
|
||||
logger.info(f"Will evaluate every {eval_steps} steps")
|
||||
logger.info(f"Will save every {save_steps} steps")
|
||||
logger.info(f"Will log every {logging_steps} steps")
|
||||
|
||||
# save_strategy should be none if output dir is none
|
||||
save_strategy = "no"
|
||||
if output_dir_path:
|
||||
save_strategy = "steps"
|
||||
training_arguments = SFTConfig(
|
||||
max_steps=config.max_steps_per_epoch,
|
||||
output_dir=str(output_dir_path) if output_dir_path is not None else None,
|
||||
num_train_epochs=config.n_epochs,
|
||||
per_device_train_batch_size=config.data_config.batch_size,
|
||||
fp16=device.type == "cuda",
|
||||
bf16=device.type != "cuda",
|
||||
# use_cpu should only be set if we are on a "True" CPU machine, not a MPS enabled Mac due to stability issues.
|
||||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||
save_strategy=save_strategy,
|
||||
save_steps=save_steps,
|
||||
report_to="none",
|
||||
max_seq_length=provider_config.max_seq_length,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||
learning_rate=lr,
|
||||
warmup_ratio=provider_config.warmup_ratio,
|
||||
weight_decay=provider_config.weight_decay,
|
||||
logging_steps=logging_steps,
|
||||
# Enable validation
|
||||
eval_strategy="steps",
|
||||
eval_steps=eval_steps,
|
||||
save_total_limit=provider_config.save_total_limit,
|
||||
remove_unused_columns=False,
|
||||
dataloader_pin_memory=provider_config.dataloader_pin_memory,
|
||||
dataloader_num_workers=provider_config.dataloader_num_workers,
|
||||
dataset_text_field="text",
|
||||
packing=False,
|
||||
# Add evaluation metrics
|
||||
# loading the best model can only happen if we have saved a model
|
||||
load_best_model_at_end=True if output_dir_path else False,
|
||||
metric_for_best_model="eval_loss",
|
||||
greater_is_better=False,
|
||||
)
|
||||
|
||||
# Initialize trainer with both train and eval datasets
|
||||
trainer = SFTTrainer(
|
||||
model=model_obj,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
args=training_arguments,
|
||||
)
|
||||
|
||||
# Train
|
||||
logger.info("Starting training")
|
||||
try:
|
||||
trainer.train()
|
||||
memory_stats["after_training"] = get_memory_stats(device)
|
||||
|
||||
# Save final model
|
||||
model_obj.config.use_cache = True
|
||||
# if we have LoRA we need to do `merge_and_unload`
|
||||
if lora_config:
|
||||
model_obj = trainer.model.merge_and_unload()
|
||||
else:
|
||||
model_obj = trainer.model
|
||||
|
||||
checkpoint = None
|
||||
checkpoints = None
|
||||
# only save a final model if checkpoint dir is specified
|
||||
# this is especially useful to test training rather than saving of checkpoints
|
||||
if output_dir_path:
|
||||
model_obj.save_pretrained(output_dir_path / "merged_model")
|
||||
|
||||
# Create checkpoint
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{model}-sft-{config.n_epochs}",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=config.n_epochs,
|
||||
post_training_job_id=job_uuid,
|
||||
path=str(output_dir_path / "merged_model"),
|
||||
)
|
||||
checkpoints = [checkpoint]
|
||||
|
||||
return memory_stats, checkpoints
|
||||
finally:
|
||||
# Clean up resources
|
||||
if hasattr(trainer, "model"):
|
||||
if device.type != "cpu":
|
||||
trainer.model.to("cpu")
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
del trainer.model
|
||||
del trainer
|
||||
gc.collect()
|
||||
memory_stats["final"] = get_memory_stats(device)
|
||||
|
||||
async def _setup_data(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Load dataset from llama stack dataset provider"""
|
||||
try:
|
||||
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
limit=-1,
|
||||
)
|
||||
|
||||
all_rows = await fetch_rows(dataset_id)
|
||||
if not isinstance(all_rows.data, list):
|
||||
raise RuntimeError("Expected dataset data to be a list")
|
||||
return all_rows.data
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
|
|
@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]:
|
|||
Api.datasets,
|
||||
],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::huggingface",
|
||||
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue