diff --git a/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py b/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py new file mode 100755 index 000000000..d993491f5 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/finetune_handler.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import asyncio +import json +import os +from typing import Any + +from llama_stack.apis.post_training import TrainingConfig +from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig +from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import ( + HFFinetuningMultiDevice, +) +from llama_stack.providers.utils.scheduler import JobStatus + + +async def train( + job_uuid, + model, + checkpoint_dir, + training_config, + provider_config, + algorithm_config, + data, + enable_nccl_debug=False, + nccl_debug_subsys="NONE", +): + """Handler function for HuggingFace training that can be called by torchrun. + + This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class. + It follows the same flow, but is designed to be called directly from a script. + + Args: + job_uuid: Unique ID for this job + model: Model to train + checkpoint_dir: Directory to save checkpoints to + training_config: Training configuration + provider_config: Provider configuration + algorithm_config: Algorithm configuration + data: the dataset rows to be processed + enable_nccl_debug: Whether to enable NCCL debugging + nccl_debug_subsys: NCCL subsystem to debug + """ + # Get rank information when running distributed + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + parsed_data: list[dict[str, Any]] = json.loads(data) + + # Set up callback functions with rank information + def on_log_message_cb(msg): + print(f"[RANK {local_rank}] {msg}", flush=True) + + def on_status_change_cb(status): + print(f"[RANK {local_rank}] Status: {status}", flush=True) + + def on_artifact_collected_cb(artifact): + print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True) + + on_log_message_cb("Starting HF finetuning") + + recipe_obj = HFFinetuningMultiDevice( + job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data + ) + + resources_allocated, checkpoints = await recipe_obj.train( + model=model, + output_dir=checkpoint_dir, + job_uuid=job_uuid, + lora_config=algorithm_config, + config=training_config, + provider_config=provider_config, + ) + + def resources_stats_to_artifact(resources_stats): + return { + "type": "resources_stats", + "name": "resources_stats", + "metadata": resources_stats, + } + + def checkpoint_to_artifact(checkpoint): + return { + "type": "checkpoint", + "name": checkpoint.identifier, + "uri": checkpoint.path, + "metadata": dict(checkpoint), + } + + on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + + on_status_change_cb(JobStatus.completed) + on_log_message_cb("HF finetuning completed") + + +async def main(): + parser = argparse.ArgumentParser(description="Run HuggingFace 
training with torchrun.")
+    parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
+    parser.add_argument("--model", type=str, required=True, help="Model to use")
+    parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
+    parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
+    parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
+    parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
+    parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
+    parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
+    parser.add_argument("--data", type=str, required=True, help="JSON-serialized dataset rows")
+
+    args = parser.parse_args()
+
+    # Parse JSON configs
+    try:
+        training_config = TrainingConfig.model_validate_json(args.training_config)
+    except Exception as e:
+        print(f"Error parsing training_config: {e}")
+        print(f"Received: {args.training_config}")
+        raise
+
+    try:
+        provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
+    except Exception as e:
+        print(f"Error parsing provider_config: {e}")
+        print(f"Received: {args.provider_config}")
+        raise
+
+    algorithm_config = None
+    if args.algorithm_config:
+        try:
+            algorithm_config = json.loads(args.algorithm_config)
+        except json.JSONDecodeError as e:
+            print(f"Error parsing algorithm_config: {e}")
+            print(f"Received: {args.algorithm_config}")
+            raise
+
+    # Print arguments for debugging
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if local_rank == 0:  # Only the main process prints
+        print("Starting training with arguments:")
+        print(f"  job_uuid: {args.job_uuid}")
+        print(f"  model: {args.model}")
+        print(f"  checkpoint_dir: {args.checkpoint_dir}")
+        print(f"  enable_nccl_debug: {args.enable_nccl_debug}")
+        print(f"  nccl_debug_subsys: {args.nccl_debug_subsys}")
+
+    await train(
+        job_uuid=args.job_uuid,
+        model=args.model,
+        checkpoint_dir=args.checkpoint_dir,
+        training_config=training_config,
+        provider_config=provider_config,
+        algorithm_config=algorithm_config,
+        enable_nccl_debug=args.enable_nccl_debug,
+        nccl_debug_subsys=args.nccl_debug_subsys,
+        data=args.data,
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
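For reference, the distributed scheduler backend added later in this patch launches the handler script above through torchrun, turning every entry of the job's run_params into a command-line flag (booleans become bare flags, world_size is skipped). A minimal sketch of that invocation, with placeholder values for everything job-specific and a hypothetical script path, looks like this; the flag names match the argparse definitions above and the torchrun options mirror the ones hard-coded in DistributedJobScheduler:

run_params = {
    "job_uuid": "job-1234",                                   # placeholder
    "model": "meta-llama/Llama-3.2-1B-Instruct",              # placeholder model id
    "training_config": "<TrainingConfig JSON>",               # placeholder payload
    "provider_config": "<HuggingFacePostTrainingConfig JSON>",
    "data": '[{"input_query": "...", "chat_completion_input": "...", "expected_answer": "..."}]',
    "enable_nccl_debug": False,
}

cmd = ["torchrun", "--nproc_per_node=2", "--master_addr=localhost", "--master_port=29500", "finetune_handler.py"]
for name, value in run_params.items():
    if isinstance(value, bool):
        if value:
            cmd.append(f"--{name}")  # boolean flags are passed without a value
    elif value is not None:
        cmd.extend([f"--{name}", str(value)])
# The distributed backend then runs this command with asyncio.create_subprocess_exec(*cmd)
# and streams its stdout back into the job's log messages.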
diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py
index 33bce882c..7f0d5d1a3 100644
--- a/llama_stack/providers/inline/post_training/huggingface/post_training.py
+++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 from enum import Enum
 from typing import Any
@@ -81,43 +82,130 @@ class HuggingFacePostTrainingImpl:
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob:
-        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
-            on_log_message_cb("Starting HF finetuning")
-
-            recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
-                datasetio_api=self.datasetio_api,
-                datasets_api=self.datasets_api,
-            )
-            if self.config.recipe == "multi":
-                recipe = HFFinetuningMultiDevice(
-                    job_uuid=job_uuid,
-                    datasetio_api=self.datasetio_api,
-                    datasets_api=self.datasets_api,
-                    enable_nccl_debug=self.config.enable_nccl_debug,
-                    nccl_debug_subsys=self.config.nccl_debug_subsys,
-                )
-
-            resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
-                provider_config=self.config,
-            )
-
-            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
-            if checkpoints:
-                for checkpoint in checkpoints:
-                    artifact = self._checkpoint_to_artifact(checkpoint)
-                    on_artifact_collected_cb(artifact)
-
-            on_status_change_cb(SchedulerJobStatus.completed)
-            on_log_message_cb("HF finetuning completed")
-
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
-        return PostTrainingJob(job_uuid=job_uuid)
+        from collections.abc import Callable, Coroutine
+        from typing import Any
+
+        # Type for the handler: async fn taking the 3 callbacks, returns Awaitable[None]
+        handler: (
+            Callable[
+                [Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
+                Coroutine[Any, Any, None],
+            ]
+            | None
+        ) = None
+
+        # Determine world size for distributed training
+        world_size = getattr(self.config, "world_size", 1)
+
+        # Load the dataset rows once; they are handed to the recipe directly (single device)
+        # or serialized for the torchrun handler script (multi device)
+        assert training_config.data_config is not None
+        data = await self._setup_data(dataset_id=training_config.data_config.dataset_id)
+
+        # Create parameters for the handler script
+        run_params = {
+            "job_uuid": job_uuid,
+            "model": model,
+            "world_size": world_size,
+        }
+
+        # Add optional parameters
+        if checkpoint_dir is not None:
+            run_params["checkpoint_dir"] = checkpoint_dir
+
+        if training_config is not None:
+            run_params["training_config"] = training_config.model_dump_json()
+
+        if algorithm_config is not None:
+            run_params["algorithm_config"] = algorithm_config.model_dump_json()
+
+        # Add provider-specific configuration
+        run_params["provider_config"] = self.config.model_dump_json()
+
+        # Add NCCL debug settings if present
+        if hasattr(self.config, "enable_nccl_debug"):
+            run_params["enable_nccl_debug"] = self.config.enable_nccl_debug
+
+        if hasattr(self.config, "nccl_debug_subsys"):
+            run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys
+
+        run_params["data"] = json.dumps(data)
+
+        # Choose the scheduler backend based on world size
+        if world_size > 1:
+            # Distributed training: the backend launches finetune_handler.py via torchrun,
+            # so no in-process handler is needed
+            self._scheduler = Scheduler(backend="distributed")
+        else:
+            self._scheduler = Scheduler(backend="naive")
+
+            # TODO: this can probably be cleaner
+            # Single-device training path: define an in-process handler
+            async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+                on_log_message_cb("Starting HF finetuning")
+
+                recipe = HFFinetuningSingleDevice(
+                    job_uuid=job_uuid,
+                    data=data,
+                )
+                if self.config.recipe == "multi":
+                    recipe = HFFinetuningMultiDevice(
+                        job_uuid=job_uuid,
+                        data=data,
+                        enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
+                        nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
+                    )
+
+                resources_allocated, checkpoints = await recipe.train(
+                    model=model,
+                    output_dir=checkpoint_dir,
+                    job_uuid=job_uuid,
+                    lora_config=algorithm_config,
+                    config=training_config,
+                    provider_config=self.config,
+                )
+
+                on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+                if checkpoints:
+                    for checkpoint in checkpoints:
+                        artifact = self._checkpoint_to_artifact(checkpoint)
+                        on_artifact_collected_cb(artifact)
+
+                on_status_change_cb(SchedulerJobStatus.completed)
+                on_log_message_cb("HF finetuning completed")
+
+        # Schedule the job with the chosen scheduler backend and the handler (None for torchrun jobs)
+        job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)
+
+        return PostTrainingJob(job_uuid=job_id)
+
+    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider.
+
+        Args:
+            dataset_id: ID of the dataset to load
+
+        Returns:
+            list: List of dataset rows
+
+        Raises:
+            RuntimeError: If dataset loading fails
+        """
+        try:
+            all_rows = await self.datasetio_api.iterrows(
+                dataset_id=dataset_id,
+                limit=-1,
+            )
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
 
     async def preference_optimize(
         self,
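The rows returned by _setup_data() are plain dicts; the recipes validate them before building the HF dataset. Purely as an illustration (field values invented, and assuming chat_completion_input carries a JSON-encoded message list as in the existing recipes), the payload that ends up in run_params["data"] looks like:

import json

rows = [
    {
        "input_query": "What is the capital of France?",
        "chat_completion_input": '[{"role": "user", "content": "What is the capital of France?"}]',
        "expected_answer": "Paris",
    }
]
payload = json.dumps(rows)          # value passed as --data to finetune_handler.py
assert json.loads(payload) == rows  # round-trips unchanged on the torchrun side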
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py
index cf1da1d33..ab888c42c 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py
@@ -75,8 +75,6 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
@@ -191,8 +189,7 @@ class HFFinetuningMultiDevice:
     def __init__(
         self,
         job_uuid: str,
-        datasetio_api: DatasetIO,
-        datasets_api: Datasets,
+        data: list[dict[str, Any]],
         enable_nccl_debug: bool = False,
         nccl_debug_subsys: str = "NONE",
     ):
@@ -203,8 +200,7 @@ class HFFinetuningMultiDevice:
-            datasetio_api: API for dataset I/O operations
-            datasets_api: API for dataset management
+            data: Pre-loaded dataset rows to fine-tune on
         """
-        self.datasetio_api = datasetio_api
-        self.datasets_api = datasets_api
+        self.data = data
         self.job_uuid = job_uuid
         self.enable_nccl_debug = enable_nccl_debug
         self.nccl_debug_subsys = nccl_debug_subsys
@@ -408,29 +404,6 @@ class HFFinetuningMultiDevice:
             num_proc=1,  # Single process to avoid issues
         )
 
-    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
-        """Load dataset from llama stack dataset provider.
- - Args: - dataset_id: ID of the dataset to load - - Returns: - list: List of dataset rows - - Raises: - RuntimeError: If dataset loading fails - """ - try: - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - if not isinstance(all_rows.data, list): - raise RuntimeError("Expected dataset data to be a list") - return all_rows.data - except Exception as e: - raise RuntimeError(f"Failed to load dataset: {str(e)}") from e - def _run_training_sync( self, local_rank: int, # First parameter must be local_rank for spawn @@ -627,10 +600,9 @@ class HFFinetuningMultiDevice: # Load dataset logger.info(f"Loading dataset: {config.data_config.dataset_id}") - rows = await self._setup_data(config.data_config.dataset_id) - if not self.validate_dataset_format(rows): + if not self.validate_dataset_format(self.data): raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") - logger.info(f"Loaded {len(rows)} rows from dataset") + logger.info(f"Loaded {len(self.data)} rows from dataset") # Initialize tokenizer logger.info(f"Initializing tokenizer for model: {model}") @@ -662,7 +634,7 @@ class HFFinetuningMultiDevice: # Create and preprocess dataset logger.info("Creating and preprocessing dataset") try: - ds = self._create_dataset(rows, config, provider_config) + ds = self._create_dataset(self.data, config, provider_config) ds = self._preprocess_dataset(ds, tokenizer, provider_config) logger.info(f"Dataset created with {len(ds)} examples") except Exception as e: @@ -1021,37 +993,16 @@ class HFFinetuningMultiDevice: config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig, ) -> tuple[dict[str, Any], list[Checkpoint] | None]: - """Train a model using HuggingFace's SFTTrainer with distributed training. - - The distributed training setup works as follows: - 1. Parse the device list to determine number of GPUs - 2. Use torch.multiprocessing.spawn to launch one process per GPU - 3. Each process runs _run_training_sync with a unique rank - 4. The processes coordinate through NCCL backend - 5. FSDP handles model sharding across GPUs - 6. Only rank 0 handles saving checkpoints and logging - - Args: - model: The model identifier to load - output_dir: Optional directory to save checkpoints - job_uuid: Unique identifier for this training job - lora_config: LoRA configuration for parameter-efficient fine-tuning - config: General training configuration - provider_config: Provider-specific configuration - Returns: - tuple: (memory_stats, checkpoints) - """ - + """Train a model using HuggingFace's SFTTrainer with distributed training.""" if provider_config.distributed_backend != "fsdp": raise RuntimeError("Must enable FSDP as distributed backend to use this recipe") # Configure NCCL logging based on debug settings configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys) - # Parse device list to determine number of GPUs - devices = [d.strip() for d in provider_config.device.split(",")] - world_size = len(devices) - logger.info(f"Using {world_size} devices: {devices}") + # Get local rank and world size from environment variables + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) output_dir_path = None if output_dir: @@ -1081,32 +1032,22 @@ class HFFinetuningMultiDevice: raise ValueError("DataConfig is required for training") try: - # Launch distributed training processes - # torch.multiprocessing.spawn will: - # 1. Create world_size number of processes - # 2. 
Call _run_training_sync for each process - # 3. Pass unique local_rank to each process - # 4. Handle process coordination and cleanup - logger.info("Starting distributed training processes") - torch.multiprocessing.spawn( - self._run_training_sync, - args=( - world_size, - model, - provider_config.model_dump(), - peft_config, - config.model_dump(), - output_dir_path, - ), - nprocs=world_size, - join=True, # Wait for all processes to complete + # Run training for this process + await self._run_training( + model=model, + provider_config=provider_config.model_dump(), + peft_config=peft_config, + config=config.model_dump(), + output_dir_path=output_dir_path, + local_rank=local_rank, + world_size=world_size, ) memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0")) - # Create checkpoint on rank 0 + # Only create checkpoint on rank 0 checkpoints = None - if output_dir_path: + if output_dir_path and local_rank == 0: checkpoint = Checkpoint( identifier=f"{model}-sft-{config.n_epochs}", created_at=datetime.now(timezone.utc), diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index b6d13b029..90eb143b3 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -37,8 +37,6 @@ from transformers import ( ) from trl import SFTConfig, SFTTrainer -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( Checkpoint, DataConfig, @@ -136,11 +134,9 @@ class HFFinetuningSingleDevice: def __init__( self, job_uuid: str, - datasetio_api: DatasetIO, - datasets_api: Datasets, + data: list[dict[str, Any]], ): - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api + self.data = data self.job_uuid = job_uuid def validate_dataset_format(self, rows: list[dict]) -> bool: @@ -262,19 +258,6 @@ class HFFinetuningSingleDevice: remove_columns=ds.column_names, ) - async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: - """Load dataset from llama stack dataset provider""" - try: - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - if not isinstance(all_rows.data, list): - raise RuntimeError("Expected dataset data to be a list") - return all_rows.data - except Exception as e: - raise RuntimeError(f"Failed to load dataset: {str(e)}") from e - def _run_training_sync( self, model: str, @@ -327,10 +310,9 @@ class HFFinetuningSingleDevice: # Load dataset logger.info(f"Loading dataset: {config.data_config.dataset_id}") - rows = await self._setup_data(config.data_config.dataset_id) - if not self.validate_dataset_format(rows): + if not self.validate_dataset_format(self.data): raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") - logger.info(f"Loaded {len(rows)} rows from dataset") + logger.info(f"Loaded {len(self.data)} rows from dataset") # Initialize tokenizer logger.info(f"Initializing tokenizer for model: {model}") @@ -362,7 +344,7 @@ class HFFinetuningSingleDevice: # Create and preprocess dataset logger.info("Creating and preprocessing dataset") try: - ds = self._create_dataset(rows, config, provider_config) + ds = self._create_dataset(self.data, config, provider_config) ds = self._preprocess_dataset(ds, tokenizer, 
provider_config) logger.info(f"Dataset created with {len(ds)} examples") except Exception as e: diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 845ab1f02..e661d2bf6 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -7,10 +7,12 @@ import abc import asyncio import functools +import multiprocessing import threading from collections.abc import Callable, Coroutine, Iterable from datetime import datetime, timezone from enum import Enum +from pathlib import Path from typing import Any, TypeAlias from pydantic import BaseModel @@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed} class Job: - def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler): + def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None): super().__init__() self.id = job_id self._type = job_type @@ -62,9 +64,38 @@ class Job: self._artifacts: list[JobArtifact] = [] self._logs: list[LogMessage] = [] self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)] + self._child_processes: list[multiprocessing.Process] = [] + self._world_size: int = 1 # Number of processes for distributed training + self.run_args: dict[str, Any] = {} # Dictionary to store run arguments @property - def handler(self) -> JobHandler: + def world_size(self) -> int: + return self._world_size + + @world_size.setter + def world_size(self, size: int) -> None: + self._world_size = size + + def add_child_process(self, process: multiprocessing.Process) -> None: + self._child_processes.append(process) + + def cancel(self) -> None: + """Cancel the job and all its child processes.""" + for process in self._child_processes: + if process.is_alive(): + process.terminate() + process.join(timeout=5) + self.status = JobStatus.failed + + def cleanup(self) -> None: + """Clean up any remaining child processes.""" + for process in self._child_processes: + if process.is_alive(): + process.terminate() + process.join(timeout=5) + + @property + def handler(self) -> JobHandler | None: return self._handler @property @@ -111,10 +142,6 @@ class Job: def append_log(self, message: LogMessage) -> None: self._logs.append(message) - # TODO: implement - def cancel(self) -> None: - raise NotImplementedError - class _SchedulerBackend(abc.ABC): @abc.abstractmethod @@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): def __init__(self, timeout: int = 5): self._timeout = timeout self._loop = asyncio.new_event_loop() - # There may be performance implications of using threads due to Python - # GIL; may need to measure if it's a real problem though self._thread = threading.Thread(target=self._run_loop, daemon=True) self._thread.start() @@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): self._loop.run_forever() # When stopping the loop, give tasks a chance to finish - # TODO: should we explicitly inform jobs of pending stoppage? 
for task in asyncio.all_tasks(self._loop): self._loop.run_until_complete(task) self._loop.close() @@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() - # TODO: decouple scheduling and running the job def schedule( self, job: Job, @@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend): try: job.status = JobStatus.running await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb) + job.status = JobStatus.completed except Exception as e: on_log_message_cb(str(e)) job.status = JobStatus.failed @@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend): pass +class DistributedJobScheduler(_SchedulerBackend): + """A scheduler backend that supports distributed training jobs. + + This scheduler uses torchrun to handle distributed training process spawning and coordination. + torchrun automatically handles: + - Process spawning + - Environment variable setup + - Process group initialization + - Error handling and process cleanup + """ + + def __init__(self, timeout: int = 5): + self._timeout = timeout + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {} + + def _run_loop(self) -> None: + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + # When stopping the loop, give tasks a chance to finish + for task in asyncio.all_tasks(self._loop): + self._loop.run_until_complete(task) + self._loop.close() + + async def shutdown(self) -> None: + # Clean up any remaining processes + for process in self._active_jobs.values(): + if process.returncode is None: # Process is still running + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join() + + def schedule( + self, + job: Job, + on_log_message_cb: Callable[[str], None], + on_status_change_cb: Callable[[JobStatus], None], + on_artifact_collected_cb: Callable[[JobArtifact], None], + ) -> None: + async def do(): + try: + job.status = JobStatus.running + + # If this is a distributed training job, use torchrun + if job.world_size > 1: + # Find the path to finetune_handler.py + from llama_stack.providers.inline.post_training.huggingface import finetune_handler + + handler_path = Path(finetune_handler.__file__) + + # Prepare arguments for the handler script + args = [ + "torchrun", + f"--nproc_per_node={job.world_size}", + "--master_addr=localhost", + "--master_port=29500", + str(handler_path), + ] + + # Add arguments from the job.run_args dictionary as proper command-line flags + for arg_name, arg_value in job.run_args.items(): + # Skip world_size as we've already handled it + if arg_name == "world_size": + continue + + if arg_value is not None: + # Handle boolean flags + if isinstance(arg_value, bool): + if arg_value: + args.append(f"--{arg_name}") + else: + # For non-boolean values, we add the argument as a separate flag and value + args.append(f"--{arg_name}") + args.append(str(arg_value)) + + # Launch torchrun using asyncio + on_log_message_cb(f"Launching distributed training with {job.world_size} processes") + on_log_message_cb(f"Command: {' '.join(args)}") + + # Make sure we capture stdout and stderr + process = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + 
stderr=asyncio.subprocess.STDOUT, + ) + + # Store process for this job + self._active_jobs[job.id] = process + + # Start monitoring in a separate task so we don't block + asyncio.create_task( + self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb) + ) + else: + # For single-device training, call the handler directly if provided + if job.handler: + await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb) + job.status = JobStatus.completed + else: + on_log_message_cb("No handler function provided for single-device training") + job.status = JobStatus.failed + except Exception as e: + on_log_message_cb(str(e)) + job.status = JobStatus.failed + logger.exception(f"Job {job.id} failed.") + + asyncio.run_coroutine_threadsafe(do(), self._loop) + + async def _monitor_process( + self, + job: Job, + process: asyncio.subprocess.Process, + script_path: Path | None, + on_log_message_cb: Callable[[str], None], + on_status_change_cb: Callable[[JobStatus], None], + ) -> None: + """Monitor a process until completion.""" + try: + # Stream output from the process if stdout is available + if process.stdout is not None: + while True: + line = await process.stdout.readline() + if not line and process.returncode is not None: + break + if line: + on_log_message_cb(line.decode().strip()) + else: + # If stdout is not available, just wait for the process to complete + on_log_message_cb("Process stdout not available, waiting for completion") + await process.wait() + + # Wait for process to complete if not already done + if process.returncode is None: + await process.wait() + + # Check if process failed + if process.returncode != 0: + on_log_message_cb(f"Training failed with return code {process.returncode}") + job.status = JobStatus.failed + else: + on_status_change_cb(JobStatus.completed) + job.status = JobStatus.completed + except Exception as e: + on_log_message_cb(f"Error monitoring process: {str(e)}") + job.status = JobStatus.failed + logger.exception(f"Error monitoring process for job {job.id}") + finally: + # Clean up temporary files + if script_path and script_path.exists(): + script_path.unlink() + + # Remove from active jobs + if job.id in self._active_jobs: + del self._active_jobs[job.id] + + def on_log_message_cb(self, job: Job, message: LogMessage) -> None: + pass + + def on_status_change_cb(self, job: Job, status: JobStatus) -> None: + pass + + def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None: + pass + + _BACKENDS = { "naive": _NaiveSchedulerBackend, + "distributed": DistributedJobScheduler, } @@ -230,11 +429,18 @@ class Scheduler: job.register_artifact(artifact) self._backend.on_artifact_collected_cb(job, artifact) - def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID: + def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID: job = Job(type_, job_id, handler) if job.id in self._jobs: raise ValueError(f"Job {job.id} already exists") + # Set world size if provided + if "world_size" in run_params: + job.world_size = run_params["world_size"] + + # Store all run parameters in the job's run_args dictionary + job.run_args = run_params + self._jobs[job.id] = job job.status = JobStatus.scheduled self._backend.schedule(