feat: DistributedJobScheduler

Rather than handling multi-GPU training inside a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which launches a `finetune_handler.py` script via torchrun. The handler parses the training arguments with argparse and calls the appropriate recipe, as `post_training.py` used to do. torchrun takes care of environment variables such as WORLD_SIZE and LOCAL_RANK.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-06-12 13:59:06 -04:00
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions
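As a quick orientation, the flow introduced here is: the provider assembles a run_params dict, the distributed scheduler backend turns that dict into torchrun CLI flags, and torchrun launches finetune_handler.py once per GPU with LOCAL_RANK and WORLD_SIZE set. Below is a minimal sketch of the flag-building step; the helper name and sample values are illustrative, mirroring the backend code in the scheduler diff further down.

# Illustrative only: mirrors how DistributedJobScheduler (see scheduler diff below)
# turns a job's run_params into torchrun CLI flags.
def build_torchrun_args(script: str, world_size: int, run_params: dict) -> list[str]:
    args = ["torchrun", f"--nproc_per_node={world_size}", script]
    for name, value in run_params.items():
        if name == "world_size" or value is None:
            continue  # world_size is consumed by torchrun itself
        if isinstance(value, bool):
            if value:
                args.append(f"--{name}")  # boolean flags are presence-only
        else:
            args.extend([f"--{name}", str(value)])
    return args

# build_torchrun_args("finetune_handler.py", 2, {"job_uuid": "job-1", "model": "my-model"})
# -> ["torchrun", "--nproc_per_node=2", "finetune_handler.py",
#     "--job_uuid", "job-1", "--model", "my-model"]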


@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
from typing import Any
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
HFFinetuningMultiDevice,
)
from llama_stack.providers.utils.scheduler import JobStatus
async def train(
job_uuid,
model,
checkpoint_dir,
training_config,
provider_config,
algorithm_config,
data,
enable_nccl_debug=False,
nccl_debug_subsys="NONE",
):
"""Handler function for HuggingFace training that can be called by torchrun.
This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
It follows the same flow, but is designed to be called directly from a script.
Args:
job_uuid: Unique ID for this job
model: Model to train
checkpoint_dir: Directory to save checkpoints to
training_config: Training configuration
provider_config: Provider configuration
algorithm_config: Algorithm configuration
data: JSON-encoded dataset rows to be processed
enable_nccl_debug: Whether to enable NCCL debugging
nccl_debug_subsys: NCCL subsystem to debug
"""
# Get rank information when running distributed
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
parsed_data: list[dict[str, Any]] = json.loads(data)
# Set up callback functions with rank information
def on_log_message_cb(msg):
print(f"[RANK {local_rank}] {msg}", flush=True)
def on_status_change_cb(status):
print(f"[RANK {local_rank}] Status: {status}", flush=True)
def on_artifact_collected_cb(artifact):
print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True)
on_log_message_cb("Starting HF finetuning")
recipe_obj = HFFinetuningMultiDevice(
job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data
)
resources_allocated, checkpoints = await recipe_obj.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=provider_config,
)
def resources_stats_to_artifact(resources_stats):
return {
"type": "resources_stats",
"name": "resources_stats",
"metadata": resources_stats,
}
def checkpoint_to_artifact(checkpoint):
return {
"type": "checkpoint",
"name": checkpoint.identifier,
"uri": checkpoint.path,
"metadata": dict(checkpoint),
}
on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(JobStatus.completed)
on_log_message_cb("HF finetuning completed")
async def main():
parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.")
parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
parser.add_argument("--model", type=str, required=True, help="Model to use")
parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
parser.add_argument("--data", type=str, required=True)
args = parser.parse_args()
# Parse JSON configs
try:
training_config = TrainingConfig.model_validate_json(args.training_config)
except Exception as e:
print(f"Error parsing training_config: {e}")
print(f"Received: {args.training_config}")
raise
try:
provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
except Exception as e:
print(f"Error parsing provider_config: {e}")
print(f"Received: {args.provider_config}")
raise
algorithm_config = None
if args.algorithm_config:
try:
algorithm_config = json.loads(args.algorithm_config)
except json.JSONDecodeError as e:
print(f"Error parsing algorithm_config: {e}")
print(f"Received: {args.algorithm_config}")
raise
# Print arguments for debugging
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if local_rank == 0: # Only the main process prints
print("Starting training with arguments:")
print(f" job_uuid: {args.job_uuid}")
print(f" model: {args.model}")
print(f" checkpoint_dir: {args.checkpoint_dir}")
print(f" enable_nccl_debug: {args.enable_nccl_debug}")
print(f" nccl_debug_subsys: {args.nccl_debug_subsys}")
await train(
job_uuid=args.job_uuid,
model=args.model,
checkpoint_dir=args.checkpoint_dir,
training_config=training_config,
provider_config=provider_config,
algorithm_config=algorithm_config,
enable_nccl_debug=args.enable_nccl_debug,
nccl_debug_subsys=args.nccl_debug_subsys,
data=args.data,
)
if __name__ == "__main__":
asyncio.run(main())
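Since the configs cross a process boundary as CLI strings, the contract between scheduler and handler is pydantic's JSON round-trip: model_dump_json() on the scheduler side, model_validate_json() in the handler. A minimal sketch with a stand-in model (ExampleConfig is hypothetical; the real code uses TrainingConfig and HuggingFacePostTrainingConfig):

from pydantic import BaseModel

class ExampleConfig(BaseModel):  # hypothetical stand-in for TrainingConfig
    n_epochs: int
    dataset_id: str

# Scheduler side: serialize to a JSON string suitable for a --training_config flag.
serialized = ExampleConfig(n_epochs=1, dataset_id="my-dataset").model_dump_json()

# Handler side: reconstruct the typed config from the argparse string.
restored = ExampleConfig.model_validate_json(serialized)
assert restored.n_epochs == 1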


@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from enum import Enum
from typing import Any
@@ -81,6 +82,61 @@ class HuggingFacePostTrainingImpl:
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
from collections.abc import Callable, Coroutine
from typing import Any
# Type for the handler: async fn taking 3 Any args, returns Awaitable[None]
handler: (
Callable[
[Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
Coroutine[Any, Any, None],
]
| None
) = None
# Determine world size for distributed training
world_size = getattr(self.config, "world_size", 1)
# Choose the backend and recipe based on world size
if world_size > 1:
recipe = "multi"
# Create parameters for the handler script
run_params = {
"job_uuid": job_uuid,
"model": model,
"world_size": world_size,
"recipe": recipe,
}
# Add optional parameters
if checkpoint_dir is not None:
run_params["checkpoint_dir"] = checkpoint_dir
if training_config is not None:
run_params["training_config"] = training_config.model_dump_json()
if algorithm_config is not None:
run_params["algorithm_config"] = algorithm_config.model_dump_json()
# Add provider-specific configuration
run_params["provider_config"] = self.config.model_dump_json()
# Add NCCL debug settings if present
if hasattr(self.config, "enable_nccl_debug"):
run_params["enable_nccl_debug"] = self.config.enable_nccl_debug
if hasattr(self.config, "nccl_debug_subsys"):
run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys
# Initialize the scheduler with the distributed backend
self._scheduler = Scheduler(backend="distributed")
else:
self._scheduler = Scheduler(backend="naive")
# TODO: this can probably be cleaner
# Single-device training path
# Define a handler for single-device training
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
@@ -94,8 +150,8 @@ class HuggingFacePostTrainingImpl:
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
enable_nccl_debug=self.config.enable_nccl_debug,
nccl_debug_subsys=self.config.nccl_debug_subsys,
enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
)
resources_allocated, checkpoints = await recipe.train(
@@ -116,8 +172,40 @@ class HuggingFacePostTrainingImpl:
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
assert training_config.data_config is not None
data = await self._setup_data(dataset_id=training_config.data_config.dataset_id)
json_data = json.dumps(data)
run_params["data"] = json_data
# Schedule the job with the regular scheduler and the handler
job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)
return PostTrainingJob(job_uuid=job_id)
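For a two-GPU job, the run_params assembled above ends up looking roughly like this; the keys match the assembly code, while the values are illustrative and the JSON strings are truncated:

# Illustrative values only; algorithm_config omitted for brevity.
run_params = {
    "job_uuid": "job-1234",
    "model": "my-model",                  # hypothetical model id
    "world_size": 2,
    "recipe": "multi",
    "checkpoint_dir": "/tmp/checkpoints",  # illustrative path
    "training_config": "{...}",            # training_config.model_dump_json(), truncated
    "provider_config": "{...}",            # self.config.model_dump_json(), truncated
    "enable_nccl_debug": False,
    "nccl_debug_subsys": "NONE",
    "data": "[...]",                       # json.dumps(dataset rows), truncated
}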
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider.
Args:
dataset_id: ID of the dataset to load
Returns:
list: List of dataset rows
Raises:
RuntimeError: If dataset loading fails
"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
async def preference_optimize(
self,


@@ -75,8 +75,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
@@ -191,8 +189,7 @@ class HFFinetuningMultiDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
data: list[dict[str, Any]],
enable_nccl_debug: bool = False,
nccl_debug_subsys: str = "NONE",
):
@@ -203,8 +200,7 @@ class HFFinetuningMultiDevice:
datasetio_api: API for dataset I/O operations
datasets_api: API for dataset management
"""
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.data = data
self.job_uuid = job_uuid
self.enable_nccl_debug = enable_nccl_debug
self.nccl_debug_subsys = nccl_debug_subsys
@@ -408,29 +404,6 @@ class HFFinetuningMultiDevice:
num_proc=1, # Single process to avoid issues
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider.
Args:
dataset_id: ID of the dataset to load
Returns:
list: List of dataset rows
Raises:
RuntimeError: If dataset loading fails
"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
local_rank: int, # First parameter must be local_rank for spawn
@@ -627,10 +600,9 @@ class HFFinetuningMultiDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
if not self.validate_dataset_format(self.data):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
logger.info(f"Loaded {len(self.data)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
@@ -662,7 +634,7 @@ class HFFinetuningMultiDevice:
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._create_dataset(self.data, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
@@ -1021,37 +993,16 @@ class HFFinetuningMultiDevice:
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer with distributed training.
The distributed training setup works as follows:
1. Parse the device list to determine number of GPUs
2. Use torch.multiprocessing.spawn to launch one process per GPU
3. Each process runs _run_training_sync with a unique rank
4. The processes coordinate through NCCL backend
5. FSDP handles model sharding across GPUs
6. Only rank 0 handles saving checkpoints and logging
Args:
model: The model identifier to load
output_dir: Optional directory to save checkpoints
job_uuid: Unique identifier for this training job
lora_config: LoRA configuration for parameter-efficient fine-tuning
config: General training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (memory_stats, checkpoints)
"""
"""Train a model using HuggingFace's SFTTrainer with distributed training."""
if provider_config.distributed_backend != "fsdp":
raise RuntimeError("Must enable FSDP as distributed backend to use this recipe")
# Configure NCCL logging based on debug settings
configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys)
# Parse device list to determine number of GPUs
devices = [d.strip() for d in provider_config.device.split(",")]
world_size = len(devices)
logger.info(f"Using {world_size} devices: {devices}")
# Get local rank and world size from environment variables
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
output_dir_path = None
if output_dir:
@@ -1081,32 +1032,22 @@ class HFFinetuningMultiDevice:
raise ValueError("DataConfig is required for training")
try:
# Launch distributed training processes
# torch.multiprocessing.spawn will:
# 1. Create world_size number of processes
# 2. Call _run_training_sync for each process
# 3. Pass unique local_rank to each process
# 4. Handle process coordination and cleanup
logger.info("Starting distributed training processes")
torch.multiprocessing.spawn(
self._run_training_sync,
args=(
world_size,
model,
provider_config.model_dump(),
peft_config,
config.model_dump(),
output_dir_path,
),
nprocs=world_size,
join=True, # Wait for all processes to complete
# Run training for this process
await self._run_training(
model=model,
provider_config=provider_config.model_dump(),
peft_config=peft_config,
config=config.model_dump(),
output_dir_path=output_dir_path,
local_rank=local_rank,
world_size=world_size,
)
memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0"))
# Create checkpoint on rank 0
# Only create checkpoint on rank 0
checkpoints = None
if output_dir_path:
if output_dir_path and local_rank == 0:
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),

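The multi-device recipe above calls configure_nccl_logging(...) before training. A helper like that typically just exports the standard NCCL debug environment variables; the sketch below is an assumption about what it might do, not the provider's actual implementation:

import os

def configure_nccl_logging_sketch(enable_nccl_debug: bool, nccl_debug_subsys: str = "NONE") -> None:
    # NCCL_DEBUG / NCCL_DEBUG_SUBSYS are standard NCCL environment variables;
    # this is only an illustration of the idea.
    if enable_nccl_debug:
        os.environ["NCCL_DEBUG"] = "INFO"
        os.environ["NCCL_DEBUG_SUBSYS"] = nccl_debug_subsys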

@@ -37,8 +37,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
data: list[dict[str, Any]],
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.data = data
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
if not self.validate_dataset_format(self.data):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
logger.info(f"Loaded {len(self.data)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._create_dataset(self.data, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:

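Both recipes validate that every dataset row carries input_query, expected_answer, and chat_completion_input. A small sketch of constructing the single-device recipe with preloaded rows; the row values are illustrative, and the import path is assumed by analogy with the multi-device recipe import:

# Import path assumed by analogy with finetune_multi_device.
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
    HFFinetuningSingleDevice,
)

# Illustrative rows; the field names come from the validation error message above.
rows = [
    {
        "input_query": "What is 2 + 2?",
        "expected_answer": "4",
        "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
    }
]

recipe = HFFinetuningSingleDevice(job_uuid="job-1234", data=rows)
# Assuming the validator checks for the presence of these keys, this passes.
assert recipe.validate_dataset_format(rows)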

@@ -7,10 +7,12 @@
import abc
import asyncio
import functools
import multiprocessing
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias
from pydantic import BaseModel
@@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
super().__init__()
self.id = job_id
self._type = job_type
@@ -62,9 +64,38 @@ class Job:
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._child_processes: list[multiprocessing.Process] = []
self._world_size: int = 1 # Number of processes for distributed training
self.run_args: dict[str, Any] = {} # Dictionary to store run arguments
@property
def handler(self) -> JobHandler:
def world_size(self) -> int:
return self._world_size
@world_size.setter
def world_size(self, size: int) -> None:
self._world_size = size
def add_child_process(self, process: multiprocessing.Process) -> None:
self._child_processes.append(process)
def cancel(self) -> None:
"""Cancel the job and all its child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
self.status = JobStatus.failed
def cleanup(self) -> None:
"""Clean up any remaining child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
@property
def handler(self) -> JobHandler | None:
return self._handler
@property
@@ -111,10 +142,6 @@ class Job:
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
@@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
@@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
@@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
@@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
@@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
pass
class DistributedJobScheduler(_SchedulerBackend):
"""A scheduler backend that supports distributed training jobs.
This scheduler uses torchrun to handle distributed training process spawning and coordination.
torchrun automatically handles:
- Process spawning
- Environment variable setup
- Process group initialization
- Error handling and process cleanup
"""
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
# Clean up any remaining processes
for process in self._active_jobs.values():
if process.returncode is None: # Process is still running
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
process.kill()
await process.wait()
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
# If this is a distributed training job, use torchrun
if job.world_size > 1:
# Find the path to finetune_handler.py
from llama_stack.providers.inline.post_training.huggingface import finetune_handler
handler_path = Path(finetune_handler.__file__)
# Prepare arguments for the handler script
args = [
"torchrun",
f"--nproc_per_node={job.world_size}",
"--master_addr=localhost",
"--master_port=29500",
str(handler_path),
]
# Add arguments from the job.run_args dictionary as proper command-line flags
for arg_name, arg_value in job.run_args.items():
# Skip world_size as we've already handled it
if arg_name == "world_size":
continue
if arg_value is not None:
# Handle boolean flags
if isinstance(arg_value, bool):
if arg_value:
args.append(f"--{arg_name}")
else:
# For non-boolean values, we add the argument as a separate flag and value
args.append(f"--{arg_name}")
args.append(str(arg_value))
# Launch torchrun using asyncio
on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
on_log_message_cb(f"Command: {' '.join(args)}")
# Make sure we capture stdout and stderr
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
# Store process for this job
self._active_jobs[job.id] = process
# Start monitoring in a separate task so we don't block
asyncio.create_task(
self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
)
else:
# For single-device training, call the handler directly if provided
if job.handler:
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
else:
on_log_message_cb("No handler function provided for single-device training")
job.status = JobStatus.failed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
async def _monitor_process(
self,
job: Job,
process: asyncio.subprocess.Process,
script_path: Path | None,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
) -> None:
"""Monitor a process until completion."""
try:
# Stream output from the process if stdout is available
if process.stdout is not None:
while True:
line = await process.stdout.readline()
if not line and process.returncode is not None:
break
if line:
on_log_message_cb(line.decode().strip())
else:
# If stdout is not available, just wait for the process to complete
on_log_message_cb("Process stdout not available, waiting for completion")
await process.wait()
# Wait for process to complete if not already done
if process.returncode is None:
await process.wait()
# Check if process failed
if process.returncode != 0:
on_log_message_cb(f"Training failed with return code {process.returncode}")
job.status = JobStatus.failed
else:
on_status_change_cb(JobStatus.completed)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(f"Error monitoring process: {str(e)}")
job.status = JobStatus.failed
logger.exception(f"Error monitoring process for job {job.id}")
finally:
# Clean up temporary files
if script_path and script_path.exists():
script_path.unlink()
# Remove from active jobs
if job.id in self._active_jobs:
del self._active_jobs[job.id]
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
"distributed": DistributedJobScheduler,
}
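Putting the backend registry and the new schedule() signature (shown in the Scheduler diff below) together, driving a distributed job might look roughly like this; the job type string, model id, and parameter values are illustrative, and the handler may be None because torchrun runs the handler script instead:

# Illustrative usage of the distributed backend and the new schedule() signature.
scheduler = Scheduler(backend="distributed")
run_params = {
    "job_uuid": "job-1234",
    "model": "my-model",         # hypothetical model id
    "world_size": 2,             # used by the backend for --nproc_per_node
    "training_config": "{...}",  # TrainingConfig JSON, truncated here
    "data": "[...]",             # dataset rows as JSON, truncated here
}
# The real provider passes _JOB_TYPE_SUPERVISED_FINE_TUNE as the job type.
job_id = scheduler.schedule("supervised-fine-tune", "job-1234", handler=None, run_params=run_params)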
@@ -230,11 +429,18 @@ class Scheduler:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
# Set world size if provided
if "world_size" in run_params:
job.world_size = run_params["world_size"]
# Store all run parameters in the job's run_args dictionary
job.run_args = run_params
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(