diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py
index 06c6d8073..e37c51e6e 100644
--- a/llama_stack/providers/inline/post_training/huggingface/config.py
+++ b/llama_stack/providers/inline/post_training/huggingface/config.py
@@ -57,7 +57,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
 
     # L2 regularization coefficient
     # Helps prevent overfitting
-    weight_decay: float = 0.01
+    weight_decay: float = 0.00
 
     # Number of worker processes for data loading
     # Higher values can improve data loading speed but increase memory usage
@@ -67,6 +67,17 @@ class HuggingFacePostTrainingConfig(BaseModel):
     # Can improve data transfer speed to GPU but uses more memory
     dataloader_pin_memory: bool = True
 
+    # Recipe type for training (single or multi device)
+    recipe: str = "single"
+
+    # NCCL debug configuration for distributed training
+    # Enable detailed NCCL logging for debugging distributed training issues
+    enable_nccl_debug: bool = False
+
+    # NCCL subsystems to debug (NONE, ALL, INIT, COLL, P2P, SHM, NET)
+    # Controls which NCCL components generate debug output
+    nccl_debug_subsys: str = "NONE"
+
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
+        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu", "recipe": "single"}
diff --git a/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py b/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py
new file mode 100755
index 000000000..d993491f5
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import asyncio
+import json
+import os
+from typing import Any
+
+from llama_stack.apis.post_training import TrainingConfig
+from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
+    HFFinetuningMultiDevice,
+)
+from llama_stack.providers.utils.scheduler import JobStatus
+
+
+async def train(
+    job_uuid,
+    model,
+    checkpoint_dir,
+    training_config,
+    provider_config,
+    algorithm_config,
+    data,
+    enable_nccl_debug=False,
+    nccl_debug_subsys="NONE",
+):
+    """Handler function for HuggingFace training that can be called by torchrun.
+
+    This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
+    It follows the same flow, but is designed to be called directly from a script.
+ + Args: + job_uuid: Unique ID for this job + model: Model to train + checkpoint_dir: Directory to save checkpoints to + training_config: Training configuration + provider_config: Provider configuration + algorithm_config: Algorithm configuration + data: the dataset rows to be processed + enable_nccl_debug: Whether to enable NCCL debugging + nccl_debug_subsys: NCCL subsystem to debug + """ + # Get rank information when running distributed + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + parsed_data: list[dict[str, Any]] = json.loads(data) + + # Set up callback functions with rank information + def on_log_message_cb(msg): + print(f"[RANK {local_rank}] {msg}", flush=True) + + def on_status_change_cb(status): + print(f"[RANK {local_rank}] Status: {status}", flush=True) + + def on_artifact_collected_cb(artifact): + print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True) + + on_log_message_cb("Starting HF finetuning") + + recipe_obj = HFFinetuningMultiDevice( + job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data + ) + + resources_allocated, checkpoints = await recipe_obj.train( + model=model, + output_dir=checkpoint_dir, + job_uuid=job_uuid, + lora_config=algorithm_config, + config=training_config, + provider_config=provider_config, + ) + + def resources_stats_to_artifact(resources_stats): + return { + "type": "resources_stats", + "name": "resources_stats", + "metadata": resources_stats, + } + + def checkpoint_to_artifact(checkpoint): + return { + "type": "checkpoint", + "name": checkpoint.identifier, + "uri": checkpoint.path, + "metadata": dict(checkpoint), + } + + on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + + on_status_change_cb(JobStatus.completed) + on_log_message_cb("HF finetuning completed") + + +async def main(): + parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.") + parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID") + parser.add_argument("--model", type=str, required=True, help="Model to use") + parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints") + parser.add_argument("--training_config", type=str, required=True, help="Training config JSON") + parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON") + parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON") + parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging") + parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug") + parser.add_argument("--data", type=str, required=True) + + args = parser.parse_args() + + # Parse JSON configs + try: + training_config = TrainingConfig.model_validate_json(args.training_config) + except Exception as e: + print(f"Error parsing training_config: {e}") + print(f"Received: {args.training_config}") + raise + + try: + provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config) + except Exception as e: + print(f"Error parsing provider_config: {e}") + print(f"Received: {args.provider_config}") + raise + + algorithm_config = None + if args.algorithm_config: + try: + algorithm_config = json.loads(args.algorithm_config) + except json.JSONDecodeError 
as e: + print(f"Error parsing algorithm_config: {e}") + print(f"Received: {args.algorithm_config}") + raise + + # In a real implementation, you would get these from somewhere + # For now, we'll pass None and handle it in the train function + datasetio_api = None + datasets_api = None + + # Print arguments for debugging + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: # Only the main process prints + print("Starting training with arguments:") + print(f" job_uuid: {args.job_uuid}") + print(f" model: {args.model}") + print(f" checkpoint_dir: {args.checkpoint_dir}") + print(f" enable_nccl_debug: {args.enable_nccl_debug}") + print(f" nccl_debug_subsys: {args.nccl_debug_subsys}") + + await train( + job_uuid=args.job_uuid, + model=args.model, + checkpoint_dir=args.checkpoint_dir, + training_config=training_config, + provider_config=provider_config, + algorithm_config=algorithm_config, + datasetio_api=datasetio_api, + datasets_api=datasets_api, + enable_nccl_debug=args.enable_nccl_debug, + nccl_debug_subsys=args.nccl_debug_subsys, + data=args.data, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py index 0b2760792..7f0d5d1a3 100644 --- a/llama_stack/providers/inline/post_training/huggingface/post_training.py +++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from enum import Enum from typing import Any @@ -22,6 +23,7 @@ from llama_stack.apis.post_training import ( from llama_stack.providers.inline.post_training.huggingface.config import ( HuggingFacePostTrainingConfig, ) +from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import HFFinetuningMultiDevice from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( HFFinetuningSingleDevice, ) @@ -80,35 +82,130 @@ class HuggingFacePostTrainingImpl: checkpoint_dir: str | None = None, algorithm_config: AlgorithmConfig | None = None, ) -> PostTrainingJob: - async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): - on_log_message_cb("Starting HF finetuning") + from collections.abc import Callable, Coroutine + from typing import Any - recipe = HFFinetuningSingleDevice( - job_uuid=job_uuid, - datasetio_api=self.datasetio_api, - datasets_api=self.datasets_api, + # Type for the handler: async fn taking 3 Any args, returns Awaitable[None] + handler: ( + Callable[ + [Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]], + Coroutine[Any, Any, None], + ] + | None + ) = None + + # Determine world size for distributed training + world_size = getattr(self.config, "world_size", 1) + + # Choose the backend and recipe based on world size + if world_size > 1: + recipe = "multi" + + # Create parameters for the handler script + run_params = { + "job_uuid": job_uuid, + "model": model, + "world_size": world_size, + "recipe": recipe, + } + + # Add optional parameters + if checkpoint_dir is not None: + run_params["checkpoint_dir"] = checkpoint_dir + + if training_config is not None: + run_params["training_config"] = training_config.model_dump_json() + + if algorithm_config is not None: + run_params["algorithm_config"] = algorithm_config.model_dump_json() + + # Add 
provider-specific configuration + run_params["provider_config"] = self.config.model_dump_json() + + # Add NCCL debug settings if present + if hasattr(self.config, "enable_nccl_debug"): + run_params["enable_nccl_debug"] = self.config.enable_nccl_debug + + if hasattr(self.config, "nccl_debug_subsys"): + run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys + + # Initialize the scheduler with the distributed backend + self._scheduler = Scheduler(backend="distributed") + else: + self._scheduler = Scheduler(backend="naive") + + # TODO: this can probably be cleaner + # Single-device training path + # Define a handler for single-device training + async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + on_log_message_cb("Starting HF finetuning") + + recipe = HFFinetuningSingleDevice( + job_uuid=job_uuid, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + ) + if self.config.recipe == "multi": + recipe = HFFinetuningMultiDevice( + job_uuid=job_uuid, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False), + nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"), + ) + + resources_allocated, checkpoints = await recipe.train( + model=model, + output_dir=checkpoint_dir, + job_uuid=job_uuid, + lora_config=algorithm_config, + config=training_config, + provider_config=self.config, + ) + + on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = self._checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + + on_status_change_cb(SchedulerJobStatus.completed) + on_log_message_cb("HF finetuning completed") + + assert training_config.data_config is not None + data = self._setup_data(dataset_id=training_config.data_config.dataset_id) + + json_data = json.dumps(data) + + run_params["data"] = json_data + + # Schedule the job with the regular scheduler and the handler + job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params) + + return PostTrainingJob(job_uuid=job_id) + + async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: + """Load dataset from llama stack dataset provider. 
+ + Args: + dataset_id: ID of the dataset to load + + Returns: + list: List of dataset rows + + Raises: + RuntimeError: If dataset loading fails + """ + try: + all_rows = await self.datasetio_api.iterrows( + dataset_id=dataset_id, + limit=-1, ) - - resources_allocated, checkpoints = await recipe.train( - model=model, - output_dir=checkpoint_dir, - job_uuid=job_uuid, - lora_config=algorithm_config, - config=training_config, - provider_config=self.config, - ) - - on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) - if checkpoints: - for checkpoint in checkpoints: - artifact = self._checkpoint_to_artifact(checkpoint) - on_artifact_collected_cb(artifact) - - on_status_change_cb(SchedulerJobStatus.completed) - on_log_message_cb("HF finetuning completed") - - job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) - return PostTrainingJob(job_uuid=job_uuid) + if not isinstance(all_rows.data, list): + raise RuntimeError("Expected dataset data to be a list") + return all_rows.data + except Exception as e: + raise RuntimeError(f"Failed to load dataset: {str(e)}") from e async def preference_optimize( self, diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py new file mode 100644 index 000000000..ab888c42c --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py @@ -0,0 +1,1063 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import gc +import json +import logging +import os +import signal +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import psutil +import torch +import torch.distributed as dist +from peft import LoraConfig, LoraModel +from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel, StateDictType + +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device + +# Set tokenizer parallelism environment variable +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Force PyTorch to use OpenBLAS instead of MKL +os.environ["MKL_THREADING_LAYER"] = "GNU" +os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" +os.environ["MKL_NUM_THREADS"] = "1" + + +def configure_nccl_logging( + debug: bool = False, + debug_subsys: str = "NONE", + socket_timeout: int = 1200, + async_error_handling: bool = True, + blocking_wait: bool = True, + ib_timeout: int = 120, + net_gdr_level: int = 5, + cuda_launch_blocking: bool = True, +) -> None: + """Configure NCCL environment variables for distributed training. + Args: + debug: Enable NCCL debug logging + debug_subsys: NCCL subsystems to debug (ALL, INIT, COLL, P2P, SHM, NET, etc.) 
+ socket_timeout: Socket timeout in seconds + async_error_handling: Enable async error handling + blocking_wait: Use blocking wait + ib_timeout: InfiniBand timeout in seconds + net_gdr_level: GPU Direct RDMA level (0-5) + cuda_launch_blocking: Enable CUDA launch blocking + """ + # Set NCCL environment variables + os.environ["NCCL_DEBUG"] = "INFO" if debug else "WARN" + os.environ["NCCL_DEBUG_SUBSYS"] = debug_subsys + os.environ["NCCL_SOCKET_TIMEOUT"] = str(socket_timeout) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" if async_error_handling else "0" + os.environ["NCCL_BLOCKING_WAIT"] = "1" if blocking_wait else "0" + os.environ["NCCL_IB_TIMEOUT"] = str(ib_timeout) + os.environ["NCCL_NET_GDR_LEVEL"] = str(net_gdr_level) + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" if cuda_launch_blocking else "0" + + +# Configure NCCL with default settings (minimal logging) +configure_nccl_logging() + +from datasets import Dataset +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedTokenizer, +) +from trl import SFTConfig, SFTTrainer + +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + TrainingConfig, +) + +from ..config import HuggingFacePostTrainingConfig + +logger = logging.getLogger(__name__) + + +def get_gb(to_convert: int) -> str: + """Converts memory stats to GB and formats to 2 decimal places. + + Args: + to_convert: Memory value in bytes + + Returns: + str: Memory value in GB formatted to 2 decimal places + """ + return f"{(to_convert / (1024**3)):.2f}" + + +def get_memory_stats(device: torch.device) -> dict[str, Any]: + """Get memory statistics for the given device. + + This function collects memory statistics for both system and device memory. + For CUDA devices, it tracks allocated, reserved, and max allocated memory. + For MPS devices, it tracks system memory usage since direct device stats aren't available. + For CPU devices, it tracks process memory usage. + + Args: + device: The device to get memory stats for (cuda, mps, or cpu) + + Returns: + dict: Dictionary containing memory statistics for both system and device + """ + stats = { + "system_memory": { + "total": get_gb(psutil.virtual_memory().total), + "available": get_gb(psutil.virtual_memory().available), + "used": get_gb(psutil.virtual_memory().used), + "percent": psutil.virtual_memory().percent, + } + } + + if device.type == "cuda": + stats["device_memory"] = { + "allocated": get_gb(torch.cuda.memory_allocated(device)), + "reserved": get_gb(torch.cuda.memory_reserved(device)), + "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)), + } + elif device.type == "mps": + # MPS doesn't provide direct memory stats, but we can track system memory + stats["device_memory"] = { + "note": "MPS memory stats not directly available", + "system_memory_used": get_gb(psutil.virtual_memory().used), + } + elif device.type == "cpu": + # For CPU, we track process memory usage + process = psutil.Process() + stats["device_memory"] = { + "process_rss": get_gb(process.memory_info().rss), + "process_vms": get_gb(process.memory_info().vms), + "process_percent": process.memory_percent(), + } + + return stats + + +def setup_distributed_training(device_str: str) -> tuple[int, int]: + """Initialize distributed training environment. + + This function sets up the distributed training environment by: + 1. Parsing the device list to determine number of GPUs + 2. Setting up environment variables for distributed training + 3. 
Initializing the process group with NCCL backend + + Args: + device_str: Comma-separated list of devices (e.g. "cuda:0,cuda:1") + + Returns: + tuple: (local_rank, world_size) where: + - local_rank is the rank of this process (0 for single device) + - world_size is the total number of processes (1 for single device) + """ + # Parse device list + devices = [d.strip() for d in device_str.split(",")] + world_size = len(devices) + + if world_size <= 1: + logger.info("Single device training") + return 0, 1 + + # Set up environment variables for distributed training + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = "0" # We're the main process + os.environ["LOCAL_RANK"] = "0" + + # Initialize process group + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + logger.info(f"Initialized distributed training with {world_size} devices: {devices}") + + return 0, world_size + + +class HFFinetuningMultiDevice: + def __init__( + self, + job_uuid: str, + data: list[dict[str, Any]], + enable_nccl_debug: bool = False, + nccl_debug_subsys: str = "NONE", + ): + """Initialize the multi-device fine-tuning handler. + + Args: + job_uuid: Unique identifier for this training job + datasetio_api: API for dataset I/O operations + datasets_api: API for dataset management + """ + self.data = data + self.job_uuid = job_uuid + self.enable_nccl_debug = enable_nccl_debug + self.nccl_debug_subsys = nccl_debug_subsys + + def validate_dataset_format(self, rows: list[dict]) -> bool: + """Validate that the dataset has the required fields. + + Args: + rows: List of dataset rows to validate + + Returns: + bool: True if all rows have required fields, False otherwise + """ + required_fields = ["input_query", "expected_answer", "chat_completion_input"] + return all(field in row for row in rows for field in required_fields) + + def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in instruct format. + + Args: + row: Dataset row containing chat completion input and expected answer + + Returns: + tuple: (input_text, output_text) or (None, None) if invalid format + """ + if "chat_completion_input" in row and "expected_answer" in row: + try: + messages = json.loads(row["chat_completion_input"]) + if not isinstance(messages, list) or len(messages) != 1: + logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") + return None, None + if "content" not in messages[0]: + logger.warning(f"Message missing content: {messages[0]}") + return None, None + return messages[0]["content"], row["expected_answer"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") + return None, None + return None, None + + def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in dialog format. 
+ + Args: + row: Dataset row containing dialog messages + + Returns: + tuple: (input_text, output_text) or (None, None) if invalid format + """ + if "dialog" in row: + try: + dialog = json.loads(row["dialog"]) + if not isinstance(dialog, list) or len(dialog) < 2: + logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") + return None, None + if dialog[0].get("role") != "user": + logger.warning(f"First message must be from user: {dialog[0]}") + return None, None + if not any(msg.get("role") == "assistant" for msg in dialog): + logger.warning("Dialog must have at least one assistant message") + return None, None + + # Convert to human/gpt format + role_map = {"user": "human", "assistant": "gpt"} + conversations = [] + for msg in dialog: + if "role" not in msg or "content" not in msg: + logger.warning(f"Message missing role or content: {msg}") + continue + conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) + + # Format as a single conversation + return conversations[0]["value"], conversations[1]["value"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse dialog: {row['dialog']}") + return None, None + return None, None + + def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row using fallback formats. + + Args: + row: Dataset row to process + + Returns: + tuple: (input_text, output_text) or (None, None) if no valid format found + """ + if "input" in row and "output" in row: + return row["input"], row["output"] + elif "prompt" in row and "completion" in row: + return row["prompt"], row["completion"] + elif "question" in row and "answer" in row: + return row["question"], row["answer"] + return None, None + + def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str: + """Format input and output text based on model requirements. + + Args: + input_text: The input text to format + output_text: The output text to format + provider_config: Configuration containing chat template + + Returns: + str: Formatted text using the chat template + """ + if hasattr(provider_config, "chat_template"): + return provider_config.chat_template.format(input=input_text, output=output_text) + return f"{input_text}\n{output_text}" + + def _create_dataset( + self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Create and preprocess the dataset. 
+ + Args: + rows: List of dataset rows to process + config: Training configuration containing data format + provider_config: Provider-specific configuration + + Returns: + Dataset: Processed dataset ready for training + + Raises: + ValueError: If no valid input/output pairs found for the specified format + """ + formatted_rows = [] + for row in rows: + input_text = None + output_text = None + + # Process based on format + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + if config.data_config.data_format.value == "instruct": + input_text, output_text = self._process_instruct_format(row) + elif config.data_config.data_format.value == "dialog": + input_text, output_text = self._process_dialog_format(row) + else: + input_text, output_text = self._process_fallback_format(row) + + if input_text and output_text: + formatted_text = self._format_text(input_text, output_text, provider_config) + formatted_rows.append({"text": formatted_text}) + + if not formatted_rows: + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + raise ValueError( + f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}" + ) + + return Dataset.from_list(formatted_rows) + + def _preprocess_dataset( + self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Preprocess the dataset with tokenizer. + + Args: + ds: Dataset to preprocess + tokenizer: Tokenizer to use for preprocessing + provider_config: Provider-specific configuration + + Returns: + Dataset: Tokenized and preprocessed dataset + """ + + def tokenize_function(examples): + # Ensure consistent padding and truncation + outputs = tokenizer( + examples["text"], + padding="max_length", # Use max_length padding for consistent dimensions + truncation=True, + max_length=provider_config.max_seq_length, + return_tensors=None, # Don't return tensors yet + return_attention_mask=True, + return_token_type_ids=False, + ) + # Add labels for causal language modeling + outputs["labels"] = outputs["input_ids"].copy() + + # Verify dimensions + assert all(len(x) == provider_config.max_seq_length for x in outputs["input_ids"]), ( + "Inconsistent input_ids length" + ) + assert all(len(x) == provider_config.max_seq_length for x in outputs["attention_mask"]), ( + "Inconsistent attention_mask length" + ) + assert all(len(x) == provider_config.max_seq_length for x in outputs["labels"]), ( + "Inconsistent labels length" + ) + + return outputs + + # Process in batches + return ds.map( + tokenize_function, + batched=True, + batch_size=1000, # Process in larger batches for efficiency + remove_columns=ds.column_names, + desc="Tokenizing and preparing dataset", + num_proc=1, # Single process to avoid issues + ) + + def _run_training_sync( + self, + local_rank: int, # First parameter must be local_rank for spawn + world_size: int, # Second parameter must be world_size for spawn + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Synchronous wrapper for running training process. + + This method serves as a bridge between the multiprocessing Process and the async training function. + It creates a new event loop to run the async training process. 
+ + Args: + local_rank: Local rank of this process (0 to world_size-1) + world_size: Total number of processes + model: The model identifier to load + provider_config: Configuration specific to the HuggingFace provider + peft_config: Optional LoRA configuration + config: General training configuration + output_dir_path: Optional path to save the model + """ + import asyncio + + logger.info("Starting training process with async wrapper") + asyncio.run( + self._run_training( + model=model, + provider_config=provider_config, + peft_config=peft_config, + config=config, + output_dir_path=output_dir_path, + local_rank=local_rank, + world_size=world_size, + ) + ) + + async def _run_training( + self, + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + local_rank: int, + world_size: int, + ) -> None: + """Run the training process with signal handling. + + This method handles the actual training process, including: + 1. Setting up signal handlers for graceful shutdown + 2. Initializing distributed training environment + 3. Loading and preprocessing the dataset + 4. Loading and configuring the model + 5. Setting up and running the trainer + 6. Saving the final model + 7. Cleaning up resources + + Args: + model: The model identifier to load + provider_config: Configuration specific to the HuggingFace provider + peft_config: Optional LoRA configuration + config: General training configuration + output_dir_path: Optional path to save the model + local_rank: Local rank of this process + world_size: Total number of processes + """ + + def signal_handler(signum, frame): + """Handle termination signals gracefully.""" + logger.info(f"Received signal {signum}, initiating graceful shutdown") + sys.exit(0) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + # Convert config dicts back to objects + logger.info("Initializing configuration objects") + provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) + config_obj = TrainingConfig(**config) + + # Set device for this process first + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + logger.info(f"Process {local_rank} using device {device}") + + # Set environment variables for this process + # These are used by PyTorch's distributed module to coordinate between processes + os.environ["LOCAL_RANK"] = str(local_rank) # Unique rank for this process + os.environ["RANK"] = str(local_rank) # Global rank (same as local in our case) + os.environ["WORLD_SIZE"] = str(world_size) # Total number of processes + os.environ["MASTER_ADDR"] = "localhost" # Address of the main process + os.environ["MASTER_PORT"] = "29500" # Port for process communication + + # Initialize process group with NCCL backend + # NCCL is NVIDIA's library for multi-GPU communication + # This must be called after setting environment variables and device + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank, + ) + logger.info(f"Initialized process group for rank {local_rank}") + dist.barrier() + + # Load dataset and tokenizer + train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) + + # Calculate steps per epoch + if not config_obj.data_config: + raise ValueError("DataConfig is required for training") + steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size + + # Load 
model + logger.info("Loading the base model") + model_obj = self.load_model(model, device, provider_config_obj, peft_config) + dist.barrier() + + # Setup training arguments + training_args = self.setup_training_args( + config_obj, + provider_config_obj, + output_dir_path, + peft_config, + steps_per_epoch, + model_obj, + ) + + # Initialize trainer + logger.info("Initializing SFTTrainer") + trainer = SFTTrainer( + model=model_obj, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + args=training_args, + ) + + try: + # Train + logger.info("Starting training") + trainer.train() + logger.info("Training completed successfully") + + # Save final model if output directory is provided + if output_dir_path: # and local_rank == 0: + trainer.save_model(output_dir=output_dir_path) + # self.save_model(local_rank, model, model_obj, tokenizer, trainer, peft_config, output_dir_path) + + finally: + # Clean up resources + logger.info("Cleaning up resources") + if hasattr(trainer, "model"): + evacuate_model_from_device(trainer.model, device.type) + del trainer + dist.barrier() + dist.destroy_process_group() + gc.collect() + logger.info("Cleanup completed") + + async def load_dataset( + self, + model: str, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[Dataset, Dataset, AutoTokenizer]: + """Load and prepare the dataset for training. + + This method: + 1. Loads the dataset from the dataset provider + 2. Initializes the tokenizer for the model + 3. Creates and preprocesses the dataset + 4. Splits the dataset into train and validation sets + + Args: + model: The model identifier to load + config: Training configuration + provider_config: Provider-specific configuration + + Returns: + tuple: (train_dataset, eval_dataset, tokenizer) + + Raises: + ValueError: If dataset is missing required fields + RuntimeError: If dataset loading or tokenizer initialization fails + """ + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Load dataset + logger.info(f"Loading dataset: {config.data_config.dataset_id}") + if not self.validate_dataset_format(self.data): + raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") + logger.info(f"Loaded {len(self.data)} rows from dataset") + + # Initialize tokenizer + logger.info(f"Initializing tokenizer for model: {model}") + try: + tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) + + # Set pad token to eos token if not present + # This is common for models that don't have a dedicated pad token + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + # Set padding side to right for causal language modeling + # This ensures that padding tokens don't interfere with the model's ability + # to predict the next token in the sequence + tokenizer.padding_side = "right" + + # Set truncation side to right to keep the beginning of the sequence + # This is important for maintaining context and instruction format + tokenizer.truncation_side = "right" + + # Set model max length to match provider config + # This ensures consistent sequence lengths across the training process + tokenizer.model_max_length = provider_config.max_seq_length + + logger.info("Tokenizer initialized successfully") + except Exception as e: + raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e + + # Create and preprocess dataset + logger.info("Creating and 
preprocessing dataset") + try: + ds = self._create_dataset(self.data, config, provider_config) + ds = self._preprocess_dataset(ds, tokenizer, provider_config) + logger.info(f"Dataset created with {len(ds)} examples") + except Exception as e: + raise ValueError(f"Failed to create dataset: {str(e)}") from e + + # Split dataset + logger.info("Splitting dataset into train and validation sets") + train_val_split = ds.train_test_split(test_size=0.1, seed=42) + train_dataset = train_val_split["train"] + eval_dataset = train_val_split["test"] + logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + + return train_dataset, eval_dataset, tokenizer + + def load_model( + self, + model: str, + device: torch.device, + provider_config: HuggingFacePostTrainingConfig, + peft_config: LoraConfig | None = None, + ) -> AutoModelForCausalLM: + """Load and initialize the model for training. + + This method: + 1. Loads the model configuration + 2. Determines optimal dtype based on device capabilities + 3. Loads the model with specified dtype + 4. Applies LoRA if configured + 5. Moves the model to the specified device + + Args: + model: The model identifier to load + device: The device to load the model on + provider_config: Provider-specific configuration + peft_config: Optional LoRA configuration + + Returns: + AutoModelForCausalLM: The loaded and configured model + + Raises: + RuntimeError: If model loading fails + """ + logger.info("Loading the base model") + try: + model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) + + # Determine optimal dtype based on device capabilities + if device.type == "cuda": + if torch.cuda.is_bf16_supported(): + torch_dtype = torch.bfloat16 + logger.info("Using bfloat16 precision (supported by device)") + else: + torch_dtype = torch.float16 + logger.info("Using float16 precision (bfloat16 not supported)") + else: + torch_dtype = torch.float32 + logger.info("Using float32 precision (non-CUDA device)") + + # Load model with specified dtype + model_obj = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype=torch_dtype, + quantization_config=None, + config=model_config, + **provider_config.model_specific_config, + ) + logger.info("Base model loaded") + + # Apply LoRA if configured + if peft_config: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model_obj.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument + output.requires_grad_(True) + + model_obj.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + logger.info("Applying LoRA configuration") + model_obj = LoraModel(model_obj, peft_config, "default") + logger.info("LoRA configuration applied") + else: + model_obj.gradient_checkpointing_enable() + + # Move model to device and return + # model_obj.to(device=device) + # logger.info(f"Model device: {next(model_obj.parameters()).device}") + return model_obj + except Exception as e: + raise RuntimeError(f"Failed to load model: {str(e)}") from e + + def setup_training_args( + self, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + output_dir_path: Path | None, + peft_config: LoraConfig | None, + steps_per_epoch: int, + model: AutoModelForCausalLM, + ) -> SFTConfig: + """Setup training arguments for distributed training. 
+ + The FSDP (Fully Sharded Data Parallel) configuration is split into two parts: + 1. The fsdp_config dict which contains settings that are directly used by FSDP + 2. The fsdp string parameter which contains settings that are parsed by the trainer + + This split is necessary because the trainer only passes certain FSDP settings to the actual FSDP implementation. + """ + logger.info("Configuring training arguments for distributed training") + lr = 2e-5 + if config.optimizer_config: + lr = config.optimizer_config.lr + logger.info(f"Using custom learning rate: {lr}") + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + data_config = config.data_config + + # Calculate steps + total_steps = steps_per_epoch * config.n_epochs + max_steps = min(config.max_steps_per_epoch, total_steps) + save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch + logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + + logger.info("Training configuration:") + logger.info(f"- Steps per epoch: {steps_per_epoch}") + logger.info(f"- Total steps: {total_steps}") + logger.info(f"- Max steps: {max_steps}") + logger.info(f"- Save steps: {save_steps}") + logger.info(f"- Logging steps: {logging_steps}") + + # Calculate optimal batch size based on available GPUs + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + effective_batch_size = max(1, data_config.batch_size // num_gpus) + logger.info(f"Using {effective_batch_size} batch size per GPU (total batch size: {data_config.batch_size})") + else: + effective_batch_size = data_config.batch_size + + # Determine optimal precision settings + if torch.cuda.is_available(): + fp16 = not torch.cuda.is_bf16_supported() + bf16 = torch.cuda.is_bf16_supported() + logger.info(f"Using {'bfloat16' if bf16 else 'float16'} precision") + else: + fp16 = False + bf16 = False + logger.info("Using float32 precision") + + # Configure save strategy + save_strategy = "no" + if output_dir_path: + save_strategy = "steps" # Save by steps for more frequent saves + logger.info(f"Will save checkpoints to {output_dir_path}") + + # FSDP Configuration - Part 1: Direct FSDP settings + # These settings are passed directly to the FSDP implementation + fsdp_config = { + # Enable CPU RAM efficient loading to reduce GPU memory usage during model loading + "cpu_ram_efficient_loading": True, + # Specify which transformer layer class to wrap with FSDP + # This is crucial for proper sharding of the model + "transformer_layer_cls_to_wrap": [model._no_split_modules[0]], + # Use full sharding strategy for maximum memory efficiency + "sharding_strategy": "FULL_SHARD", + # Disable forward prefetch to reduce memory usage + "forward_prefetch": False, + # Limit all-gather operations to reduce memory spikes + "limit_all_gathers": True, + # Enable parameter offloading to CPU to reduce GPU memory usage + "offload_param": True, + # Ensure module states are synchronized across processes + "sync_module_states": True, + # Enable verbose logging for debugging + "verbose": True, + # State dict settings for better checkpoint handling + "state_dict_type": "FULL_STATE_DICT", + "state_dict_config": { + "offload_to_cpu": True, # Offload state dict to CPU during saving + "rank0_only": True, # Only rank 0 saves the state dict + }, + } + + # Add LoRA-specific or full model FSDP settings + if peft_config: + # LoRA configuration - less aggressive sharding since LoRA is already memory efficient + fsdp_config.update( + { + 
"backward_prefetch": "backward_post", # Prefetch after backward pass + "activation_checkpointing": False, # No need for activation checkpointing with LoRA + "use_orig_params": False, # Don't use original parameters for LoRA + } + ) + else: + # Full model configuration - more aggressive memory optimization + fsdp_config.update( + { + "backward_prefetch": "backward_pre", # Prefetch before backward pass + "activation_checkpointing": False, # Use FSDP's built-in activation checkpointing + "use_orig_params": True, # Use original parameters for full model + } + ) + + # Set up training config + training_config = SFTConfig( + # FSDP Configuration - Part 2: Trainer-level FSDP settings + # These settings are parsed by the trainer and passed to FSDP + fsdp="full_shard auto_wrap offload", # Enable full sharding, auto wrapping, and offloading + # Pass the direct FSDP settings + fsdp_config=fsdp_config, + # Enable gradient checkpointing for memory efficiency + gradient_checkpointing=provider_config.gradient_checkpointing, + gradient_checkpointing_kwargs={ + "use_reentrant": False, # Disable reentrant checkpointing for better memory efficiency + "preserve_rng_state": False, # Don't preserve RNG state to save memory + } + if provider_config.gradient_checkpointing + else None, + # Disable torch.compile as it can interfere with FSDP + torch_compile=False, + # Training parameters + max_steps=max_steps, + dataloader_num_workers=1, # Single worker to avoid memory issues + dataloader_pin_memory=False, # Disable pin memory to reduce memory usage + optim="adamw_torch", # Use PyTorch's AdamW implementation + output_dir=str(output_dir_path) if output_dir_path is not None else None, + num_train_epochs=config.n_epochs, + per_device_train_batch_size=effective_batch_size, + fp16=fp16, + bf16=bf16, + eval_strategy="no", + use_cpu=False, + save_strategy=save_strategy, + save_steps=save_steps, + save_total_limit=3, # Keep last 3 checkpoints + save_safetensors=True, + save_only_model=True, # Only save model state, not optimizer state for FSDP compatibility + report_to="none", + max_seq_length=provider_config.max_seq_length, + gradient_accumulation_steps=config.gradient_accumulation_steps, + learning_rate=lr, + lr_scheduler_type="cosine", + warmup_steps=25, + warmup_ratio=provider_config.warmup_ratio, + weight_decay=provider_config.weight_decay, + remove_unused_columns=False, + dataset_text_field="text", + load_best_model_at_end=False, + metric_for_best_model="eval_loss", + packing=False, + greater_is_better=False, + logging_steps=logging_steps, + logging_first_step=True, + logging_dir=str(output_dir_path / "logs") if output_dir_path else None, + logging_nan_inf_filter=True, + overwrite_output_dir=True, + ) + + return training_config + + def save_model( + self, + local_rank: int, + model_path: str, + model_obj: AutoModelForCausalLM, + tokenizer: PreTrainedTokenizer, + trainer: SFTTrainer, + peft_config: LoraConfig | None, + output_dir_path: Path, + ) -> None: + """Save the trained model with proper FSDP handling. + + This method handles saving both LoRA and full models with proper FSDP state dict handling. + For LoRA models, it merges the weights with the base model before saving. 
+ + Args: + local_rank: Local rank of this process + model_path: Path to the original model + model_obj: The model to save + tokenizer: Tokenizer to save + trainer: The trainer instance + peft_config: Optional LoRA configuration + output_dir_path: Path to save the model + + Raises: + RuntimeError: If model saving fails + """ + logger.info("Saving final model") + model_obj.config.use_cache = True + save_path = output_dir_path / "final_model" + logger.info(f"Saving model to {save_path}") + + # Ensure all processes are ready to save + dist.barrier() + + try: + if peft_config: + logger.info("Merging LoRA weights with base model") + # Get full state dict with FSDP handling + sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FullyShardedDataParallel.state_dict_type(model_obj, StateDictType.FULL_STATE_DICT, sd_config): + state = model_obj.state_dict() + + if local_rank == 0: + try: + # Load a CPU copy of the base model for merging + logger.info("Loading CPU copy of base model for merging") + model_copy = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="cpu", # Ensure CPU loading + torch_dtype=torch.float32, # Use float32 for better precision during merging + ) + model_copy = LoraModel(model_copy, peft_config, "default") + + # Load the trained state and merge + logger.info("Loading trained state and merging weights") + model_copy.load_state_dict(state) + merged_model = model_copy.merge_and_unload(progressbar=True) + + # Save the merged model and tokenizer + logger.info("Saving merged model and tokenizer") + merged_model.save_pretrained(save_path, safe_serialization=True) + tokenizer.save_pretrained(save_path) + + # Clean up + del model_copy + logger.info("Successfully saved merged LoRA model and tokenizer") + except Exception as e: + logger.error(f"Failed to save merged LoRA model: {str(e)}") + raise + else: + logger.info("Saving full model with FSDP") + # For full model, use FSDP's state dict handling + if local_rank == 0: + try: + model_obj.save_pretrained(save_path, safe_serialization=True) + tokenizer.save_pretrained(save_path) + logger.info("Successfully saved full model and tokenizer") + except Exception as e: + logger.error(f"Failed to save full model: {str(e)}") + raise + finally: + # Ensure all processes wait for saving to complete + dist.barrier() + logger.info("Model saving completed") + + async def train( + self, + model: str, + output_dir: str | None, + job_uuid: str, + lora_config: LoraFinetuningConfig, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[dict[str, Any], list[Checkpoint] | None]: + """Train a model using HuggingFace's SFTTrainer with distributed training.""" + if provider_config.distributed_backend != "fsdp": + raise RuntimeError("Must enable FSDP as distributed backend to use this recipe") + + # Configure NCCL logging based on debug settings + configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys) + + # Get local rank and world size from environment variables + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + output_dir_path = None + if output_dir: + output_dir_path = Path(output_dir) + + # Track memory stats on first GPU + memory_stats = { + "initial": get_memory_stats(torch.device("cuda:0")), + "after_training": None, + "final": None, + } + + # Configure LoRA if specified + peft_config = None + if lora_config: + peft_config = LoraConfig( + lora_alpha=lora_config.alpha, + lora_dropout=0.1, + r=lora_config.rank, 
+ bias="none", + task_type="CAUSAL_LM", + target_modules=lora_config.lora_attn_modules, + ) + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + try: + # Run training for this process + await self._run_training( + model=model, + provider_config=provider_config.model_dump(), + peft_config=peft_config, + config=config.model_dump(), + output_dir_path=output_dir_path, + local_rank=local_rank, + world_size=world_size, + ) + + memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0")) + + # Only create checkpoint on rank 0 + checkpoints = None + if output_dir_path and local_rank == 0: + checkpoint = Checkpoint( + identifier=f"{model}-sft-{config.n_epochs}", + created_at=datetime.now(timezone.utc), + epoch=config.n_epochs, + post_training_job_id=job_uuid, + path=str(output_dir_path / "merged_model"), + ) + checkpoints = [checkpoint] + + return memory_stats, checkpoints + finally: + memory_stats["final"] = get_memory_stats(torch.device("cuda:0")) + gc.collect() diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index b6d13b029..90eb143b3 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -37,8 +37,6 @@ from transformers import ( ) from trl import SFTConfig, SFTTrainer -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( Checkpoint, DataConfig, @@ -136,11 +134,9 @@ class HFFinetuningSingleDevice: def __init__( self, job_uuid: str, - datasetio_api: DatasetIO, - datasets_api: Datasets, + data: list[dict[str, Any]], ): - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api + self.data = data self.job_uuid = job_uuid def validate_dataset_format(self, rows: list[dict]) -> bool: @@ -262,19 +258,6 @@ class HFFinetuningSingleDevice: remove_columns=ds.column_names, ) - async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: - """Load dataset from llama stack dataset provider""" - try: - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - if not isinstance(all_rows.data, list): - raise RuntimeError("Expected dataset data to be a list") - return all_rows.data - except Exception as e: - raise RuntimeError(f"Failed to load dataset: {str(e)}") from e - def _run_training_sync( self, model: str, @@ -327,10 +310,9 @@ class HFFinetuningSingleDevice: # Load dataset logger.info(f"Loading dataset: {config.data_config.dataset_id}") - rows = await self._setup_data(config.data_config.dataset_id) - if not self.validate_dataset_format(rows): + if not self.validate_dataset_format(self.data): raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") - logger.info(f"Loaded {len(rows)} rows from dataset") + logger.info(f"Loaded {len(self.data)} rows from dataset") # Initialize tokenizer logger.info(f"Initializing tokenizer for model: {model}") @@ -362,7 +344,7 @@ class HFFinetuningSingleDevice: # Create and preprocess dataset logger.info("Creating and preprocessing dataset") try: - ds = self._create_dataset(rows, config, provider_config) + ds = self._create_dataset(self.data, config, provider_config) ds = self._preprocess_dataset(ds, tokenizer, 
provider_config) logger.info(f"Dataset created with {len(ds)} examples") except Exception as e: diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 845ab1f02..e661d2bf6 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -7,10 +7,12 @@ import abc import asyncio import functools +import multiprocessing import threading from collections.abc import Callable, Coroutine, Iterable from datetime import datetime, timezone from enum import Enum +from pathlib import Path from typing import Any, TypeAlias from pydantic import BaseModel @@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed} class Job: - def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler): + def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None): super().__init__() self.id = job_id self._type = job_type @@ -62,9 +64,38 @@ class Job: self._artifacts: list[JobArtifact] = [] self._logs: list[LogMessage] = [] self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)] + self._child_processes: list[multiprocessing.Process] = [] + self._world_size: int = 1 # Number of processes for distributed training + self.run_args: dict[str, Any] = {} # Dictionary to store run arguments @property - def handler(self) -> JobHandler: + def world_size(self) -> int: + return self._world_size + + @world_size.setter + def world_size(self, size: int) -> None: + self._world_size = size + + def add_child_process(self, process: multiprocessing.Process) -> None: + self._child_processes.append(process) + + def cancel(self) -> None: + """Cancel the job and all its child processes.""" + for process in self._child_processes: + if process.is_alive(): + process.terminate() + process.join(timeout=5) + self.status = JobStatus.failed + + def cleanup(self) -> None: + """Clean up any remaining child processes.""" + for process in self._child_processes: + if process.is_alive(): + process.terminate() + process.join(timeout=5) + + @property + def handler(self) -> JobHandler | None: return self._handler @property @@ -111,10 +142,6 @@ class Job: def append_log(self, message: LogMessage) -> None: self._logs.append(message) - # TODO: implement - def cancel(self) -> None: - raise NotImplementedError - class _SchedulerBackend(abc.ABC): @abc.abstractmethod @@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): def __init__(self, timeout: int = 5): self._timeout = timeout self._loop = asyncio.new_event_loop() - # There may be performance implications of using threads due to Python - # GIL; may need to measure if it's a real problem though self._thread = threading.Thread(target=self._run_loop, daemon=True) self._thread.start() @@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): self._loop.run_forever() # When stopping the loop, give tasks a chance to finish - # TODO: should we explicitly inform jobs of pending stoppage? 
for task in asyncio.all_tasks(self._loop): self._loop.run_until_complete(task) self._loop.close() @@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() - # TODO: decouple scheduling and running the job def schedule( self, job: Job, @@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend): try: job.status = JobStatus.running await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb) + job.status = JobStatus.completed except Exception as e: on_log_message_cb(str(e)) job.status = JobStatus.failed @@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend): pass +class DistributedJobScheduler(_SchedulerBackend): + """A scheduler backend that supports distributed training jobs. + + This scheduler uses torchrun to handle distributed training process spawning and coordination. + torchrun automatically handles: + - Process spawning + - Environment variable setup + - Process group initialization + - Error handling and process cleanup + """ + + def __init__(self, timeout: int = 5): + self._timeout = timeout + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {} + + def _run_loop(self) -> None: + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + # When stopping the loop, give tasks a chance to finish + for task in asyncio.all_tasks(self._loop): + self._loop.run_until_complete(task) + self._loop.close() + + async def shutdown(self) -> None: + # Clean up any remaining processes + for process in self._active_jobs.values(): + if process.returncode is None: # Process is still running + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join() + + def schedule( + self, + job: Job, + on_log_message_cb: Callable[[str], None], + on_status_change_cb: Callable[[JobStatus], None], + on_artifact_collected_cb: Callable[[JobArtifact], None], + ) -> None: + async def do(): + try: + job.status = JobStatus.running + + # If this is a distributed training job, use torchrun + if job.world_size > 1: + # Find the path to finetune_handler.py + from llama_stack.providers.inline.post_training.huggingface import finetune_handler + + handler_path = Path(finetune_handler.__file__) + + # Prepare arguments for the handler script + args = [ + "torchrun", + f"--nproc_per_node={job.world_size}", + "--master_addr=localhost", + "--master_port=29500", + str(handler_path), + ] + + # Add arguments from the job.run_args dictionary as proper command-line flags + for arg_name, arg_value in job.run_args.items(): + # Skip world_size as we've already handled it + if arg_name == "world_size": + continue + + if arg_value is not None: + # Handle boolean flags + if isinstance(arg_value, bool): + if arg_value: + args.append(f"--{arg_name}") + else: + # For non-boolean values, we add the argument as a separate flag and value + args.append(f"--{arg_name}") + args.append(str(arg_value)) + + # Launch torchrun using asyncio + on_log_message_cb(f"Launching distributed training with {job.world_size} processes") + on_log_message_cb(f"Command: {' '.join(args)}") + + # Make sure we capture stdout and stderr + process = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + 
stderr=asyncio.subprocess.STDOUT, + ) + + # Store process for this job + self._active_jobs[job.id] = process + + # Start monitoring in a separate task so we don't block + asyncio.create_task( + self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb) + ) + else: + # For single-device training, call the handler directly if provided + if job.handler: + await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb) + job.status = JobStatus.completed + else: + on_log_message_cb("No handler function provided for single-device training") + job.status = JobStatus.failed + except Exception as e: + on_log_message_cb(str(e)) + job.status = JobStatus.failed + logger.exception(f"Job {job.id} failed.") + + asyncio.run_coroutine_threadsafe(do(), self._loop) + + async def _monitor_process( + self, + job: Job, + process: asyncio.subprocess.Process, + script_path: Path | None, + on_log_message_cb: Callable[[str], None], + on_status_change_cb: Callable[[JobStatus], None], + ) -> None: + """Monitor a process until completion.""" + try: + # Stream output from the process if stdout is available + if process.stdout is not None: + while True: + line = await process.stdout.readline() + if not line and process.returncode is not None: + break + if line: + on_log_message_cb(line.decode().strip()) + else: + # If stdout is not available, just wait for the process to complete + on_log_message_cb("Process stdout not available, waiting for completion") + await process.wait() + + # Wait for process to complete if not already done + if process.returncode is None: + await process.wait() + + # Check if process failed + if process.returncode != 0: + on_log_message_cb(f"Training failed with return code {process.returncode}") + job.status = JobStatus.failed + else: + on_status_change_cb(JobStatus.completed) + job.status = JobStatus.completed + except Exception as e: + on_log_message_cb(f"Error monitoring process: {str(e)}") + job.status = JobStatus.failed + logger.exception(f"Error monitoring process for job {job.id}") + finally: + # Clean up temporary files + if script_path and script_path.exists(): + script_path.unlink() + + # Remove from active jobs + if job.id in self._active_jobs: + del self._active_jobs[job.id] + + def on_log_message_cb(self, job: Job, message: LogMessage) -> None: + pass + + def on_status_change_cb(self, job: Job, status: JobStatus) -> None: + pass + + def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None: + pass + + _BACKENDS = { "naive": _NaiveSchedulerBackend, + "distributed": DistributedJobScheduler, } @@ -230,11 +429,18 @@ class Scheduler: job.register_artifact(artifact) self._backend.on_artifact_collected_cb(job, artifact) - def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID: + def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID: job = Job(type_, job_id, handler) if job.id in self._jobs: raise ValueError(f"Job {job.id} already exists") + # Set world size if provided + if "world_size" in run_params: + job.world_size = run_params["world_size"] + + # Store all run parameters in the job's run_args dictionary + job.run_args = run_params + self._jobs[job.id] = job job.status = JobStatus.scheduled self._backend.schedule( diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 85d5c813b..9acaf4aea 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ 
b/llama_stack/templates/ollama/run-with-safety.yaml
@@ -100,6 +100,7 @@ providers:
       checkpoint_format: huggingface
       distributed_backend: null
       device: cpu
+      recipe: single
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index 2d10a99a4..965777c35 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -98,6 +98,7 @@ providers:
       checkpoint_format: huggingface
       distributed_backend: null
       device: cpu
+      recipe: single
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
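
For context, the sketch below shows how the new Scheduler signature and the "distributed" backend added in llama_stack/providers/utils/scheduler.py fit together. It is a minimal illustration rather than code from this patch: the job id, model name, paths, and JSON payloads are assumed placeholder values, and real callers build run_params inside HuggingFacePostTrainingImpl.supervised_fine_tune, serializing the config objects with model_dump_json().

import json

from llama_stack.providers.utils.scheduler import Scheduler

# Placeholder inputs; real callers serialize TrainingConfig and
# HuggingFacePostTrainingConfig with model_dump_json() as post_training.py does.
run_params = {
    "job_uuid": "job-1234",  # placeholder id
    "model": "meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    "checkpoint_dir": "/tmp/hf-checkpoints",  # placeholder path
    "world_size": 2,  # > 1 selects the torchrun path in DistributedJobScheduler
    "training_config": '{"n_epochs": 1, "max_steps_per_epoch": 10, "gradient_accumulation_steps": 1}',  # placeholder JSON
    "provider_config": '{"device": "cuda", "distributed_backend": "fsdp", "recipe": "multi"}',  # placeholder JSON
    "data": json.dumps(
        [{"input_query": "hi", "expected_answer": "hello", "chat_completion_input": '[{"role": "user", "content": "hi"}]'}]
    ),  # placeholder rows; real rows come from the datasetio API
}

job_type = "supervised-fine-tune"  # stand-in for _JOB_TYPE_SUPERVISED_FINE_TUNE in post_training.py

# "distributed" maps to DistributedJobScheduler in _BACKENDS.
scheduler = Scheduler(backend="distributed")

# With world_size > 1 the backend does not call the handler (so None is allowed
# by the updated Job signature); it builds a torchrun command
# (--nproc_per_node=2 ... finetune_handler.py --job_uuid ... --data ...)
# from run_params and streams the subprocess output into the job log.
job_id = scheduler.schedule(job_type, run_params["job_uuid"], None, run_params)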