Merge ce48d47543 into 90d03552d4 (commit 6f38d12853)
8 changed files with 1598 additions and 63 deletions
@@ -57,7 +57,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
     # L2 regularization coefficient
     # Helps prevent overfitting
-    weight_decay: float = 0.01
+    weight_decay: float = 0.00

     # Number of worker processes for data loading
     # Higher values can improve data loading speed but increase memory usage
@@ -67,6 +67,17 @@ class HuggingFacePostTrainingConfig(BaseModel):
     # Can improve data transfer speed to GPU but uses more memory
     dataloader_pin_memory: bool = True

+    # Recipe type for training (single or multi device)
+    recipe: str = "single"
+
+    # NCCL debug configuration for distributed training
+    # Enable detailed NCCL logging for debugging distributed training issues
+    enable_nccl_debug: bool = False
+
+    # NCCL subsystems to debug (NONE, ALL, INIT, COLL, P2P, SHM, NET)
+    # Controls which NCCL components generate debug output
+    nccl_debug_subsys: str = "NONE"
+
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
+        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu", "recipe": "single"}
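For orientation, a minimal sketch (not part of the diff) of how the new provider options could be set programmatically. The field names come from the hunks above; the values here are purely illustrative.

from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

# Illustrative values only; the field names are taken from the diff above.
config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",
    distributed_backend=None,
    device="cpu",
    recipe="multi",            # new field: single- vs. multi-device training
    enable_nccl_debug=True,    # new field: verbose NCCL logging
    nccl_debug_subsys="INIT",  # new field: restrict debug output to NCCL init
)
print(config.model_dump_json())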
llama_stack/providers/inline/post_training/huggingface/finetune_handler.py (new executable file)
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import asyncio
+import json
+import os
+from typing import Any
+
+from llama_stack.apis.post_training import TrainingConfig
+from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
+    HFFinetuningMultiDevice,
+)
+from llama_stack.providers.utils.scheduler import JobStatus
+
+
+async def train(
+    job_uuid,
+    model,
+    checkpoint_dir,
+    training_config,
+    provider_config,
+    algorithm_config,
+    data,
+    datasetio_api=None,  # accepted because main() passes it; unused here
+    datasets_api=None,  # accepted because main() passes it; unused here
+    enable_nccl_debug=False,
+    nccl_debug_subsys="NONE",
+):
+    """Handler function for HuggingFace training that can be called by torchrun.
+
+    This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
+    It follows the same flow, but is designed to be called directly from a script.
+
+    Args:
+        job_uuid: Unique ID for this job
+        model: Model to train
+        checkpoint_dir: Directory to save checkpoints to
+        training_config: Training configuration
+        provider_config: Provider configuration
+        algorithm_config: Algorithm configuration
+        data: the dataset rows to be processed
+        enable_nccl_debug: Whether to enable NCCL debugging
+        nccl_debug_subsys: NCCL subsystem to debug
+    """
+    # Get rank information when running distributed
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+
+    parsed_data: list[dict[str, Any]] = json.loads(data)
+
+    # Set up callback functions with rank information
+    def on_log_message_cb(msg):
+        print(f"[RANK {local_rank}] {msg}", flush=True)
+
+    def on_status_change_cb(status):
+        print(f"[RANK {local_rank}] Status: {status}", flush=True)
+
+    def on_artifact_collected_cb(artifact):
+        print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True)
+
+    on_log_message_cb("Starting HF finetuning")
+
+    recipe_obj = HFFinetuningMultiDevice(
+        job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data
+    )
+
+    resources_allocated, checkpoints = await recipe_obj.train(
+        model=model,
+        output_dir=checkpoint_dir,
+        job_uuid=job_uuid,
+        lora_config=algorithm_config,
+        config=training_config,
+        provider_config=provider_config,
+    )
+
+    def resources_stats_to_artifact(resources_stats):
+        return {
+            "type": "resources_stats",
+            "name": "resources_stats",
+            "metadata": resources_stats,
+        }
+
+    def checkpoint_to_artifact(checkpoint):
+        return {
+            "type": "checkpoint",
+            "name": checkpoint.identifier,
+            "uri": checkpoint.path,
+            "metadata": dict(checkpoint),
+        }
+
+    on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated))
+    if checkpoints:
+        for checkpoint in checkpoints:
+            artifact = checkpoint_to_artifact(checkpoint)
+            on_artifact_collected_cb(artifact)
+
+    on_status_change_cb(JobStatus.completed)
+    on_log_message_cb("HF finetuning completed")
+
+
+async def main():
+    parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.")
+    parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
+    parser.add_argument("--model", type=str, required=True, help="Model to use")
+    parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
+    parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
+    parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
+    parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
+    parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
+    parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
+    parser.add_argument("--data", type=str, required=True)
+
+    args = parser.parse_args()
+
+    # Parse JSON configs
+    try:
+        training_config = TrainingConfig.model_validate_json(args.training_config)
+    except Exception as e:
+        print(f"Error parsing training_config: {e}")
+        print(f"Received: {args.training_config}")
+        raise
+
+    try:
+        provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
+    except Exception as e:
+        print(f"Error parsing provider_config: {e}")
+        print(f"Received: {args.provider_config}")
+        raise
+
+    algorithm_config = None
+    if args.algorithm_config:
+        try:
+            algorithm_config = json.loads(args.algorithm_config)
+        except json.JSONDecodeError as e:
+            print(f"Error parsing algorithm_config: {e}")
+            print(f"Received: {args.algorithm_config}")
+            raise
+
+    # In a real implementation, you would get these from somewhere
+    # For now, we'll pass None and handle it in the train function
+    datasetio_api = None
+    datasets_api = None
+
+    # Print arguments for debugging
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if local_rank == 0:  # Only the main process prints
+        print("Starting training with arguments:")
+        print(f"  job_uuid: {args.job_uuid}")
+        print(f"  model: {args.model}")
+        print(f"  checkpoint_dir: {args.checkpoint_dir}")
+        print(f"  enable_nccl_debug: {args.enable_nccl_debug}")
+        print(f"  nccl_debug_subsys: {args.nccl_debug_subsys}")
+
+    await train(
+        job_uuid=args.job_uuid,
+        model=args.model,
+        checkpoint_dir=args.checkpoint_dir,
+        training_config=training_config,
+        provider_config=provider_config,
+        algorithm_config=algorithm_config,
+        datasetio_api=datasetio_api,
+        datasets_api=datasets_api,
+        enable_nccl_debug=args.enable_nccl_debug,
+        nccl_debug_subsys=args.nccl_debug_subsys,
+        data=args.data,
+    )


+if __name__ == "__main__":
+    asyncio.run(main())
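As a sanity check on the hand-off contract above, a small self-contained sketch (not part of the diff): the scheduler side JSON-encodes the dataset rows into the --data flag, and train() recovers them with json.loads(). The row keys mirror the fields validate_dataset_format() expects; the row content is hypothetical.

import json

rows = [{"input_query": "2+2?", "expected_answer": "4", "chat_completion_input": "[]"}]
encoded = json.dumps(rows)     # what the scheduler passes via --data
decoded = json.loads(encoded)  # what train() recovers as parsed_data
assert decoded == rows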
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 from enum import Enum
 from typing import Any

@@ -22,6 +23,7 @@ from llama_stack.apis.post_training import (
 from llama_stack.providers.inline.post_training.huggingface.config import (
     HuggingFacePostTrainingConfig,
 )
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import HFFinetuningMultiDevice
 from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
     HFFinetuningSingleDevice,
 )
@@ -80,6 +82,61 @@ class HuggingFacePostTrainingImpl:
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob:
+        from collections.abc import Callable, Coroutine
+        from typing import Any
+
+        # Type for the handler: async fn taking 3 Any args, returns Awaitable[None]
+        handler: (
+            Callable[
+                [Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
+                Coroutine[Any, Any, None],
+            ]
+            | None
+        ) = None
+
+        # Determine world size for distributed training
+        world_size = getattr(self.config, "world_size", 1)
+
+        # Choose the backend and recipe based on world size
+        if world_size > 1:
+            recipe = "multi"
+
+            # Create parameters for the handler script
+            run_params = {
+                "job_uuid": job_uuid,
+                "model": model,
+                "world_size": world_size,
+                "recipe": recipe,
+            }
+
+            # Add optional parameters
+            if checkpoint_dir is not None:
+                run_params["checkpoint_dir"] = checkpoint_dir
+
+            if training_config is not None:
+                run_params["training_config"] = training_config.model_dump_json()
+
+            if algorithm_config is not None:
+                run_params["algorithm_config"] = algorithm_config.model_dump_json()
+
+            # Add provider-specific configuration
+            run_params["provider_config"] = self.config.model_dump_json()
+
+            # Add NCCL debug settings if present
+            if hasattr(self.config, "enable_nccl_debug"):
+                run_params["enable_nccl_debug"] = self.config.enable_nccl_debug
+
+            if hasattr(self.config, "nccl_debug_subsys"):
+                run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys
+
+            # Initialize the scheduler with the distributed backend
+            self._scheduler = Scheduler(backend="distributed")
+        else:
+            self._scheduler = Scheduler(backend="naive")
+
+        # TODO: this can probably be cleaner
+        # Single-device training path
+        # Define a handler for single-device training
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             on_log_message_cb("Starting HF finetuning")

@@ -88,6 +145,14 @@ class HuggingFacePostTrainingImpl:
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )
+            if self.config.recipe == "multi":
+                recipe = HFFinetuningMultiDevice(
+                    job_uuid=job_uuid,
+                    datasetio_api=self.datasetio_api,
+                    datasets_api=self.datasets_api,
+                    enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
+                    nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
+                )

             resources_allocated, checkpoints = await recipe.train(
                 model=model,
@@ -107,8 +172,40 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
-        return PostTrainingJob(job_uuid=job_uuid)
+        assert training_config.data_config is not None
+        data = await self._setup_data(dataset_id=training_config.data_config.dataset_id)
+
+        json_data = json.dumps(data)
+
+        run_params["data"] = json_data
+
+        # Schedule the job with the regular scheduler and the handler
+        job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)
+
+        return PostTrainingJob(job_uuid=job_id)
+
+    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider.
+
+        Args:
+            dataset_id: ID of the dataset to load
+
+        Returns:
+            list: List of dataset rows
+
+        Raises:
+            RuntimeError: If dataset loading fails
+        """
+        try:
+            all_rows = await self.datasetio_api.iterrows(
+                dataset_id=dataset_id,
+                limit=-1,
+            )
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e

     async def preference_optimize(
         self,
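A hypothetical example (not from the diff) of the run_params dictionary the multi-device branch assembles before scheduling; every value below is a placeholder, and the *_config entries stand in for the model_dump_json() strings built above.

run_params = {
    "job_uuid": "job-1234",                       # placeholder
    "model": "example-model",                     # placeholder
    "world_size": 2,
    "recipe": "multi",
    "checkpoint_dir": "/tmp/checkpoints",         # placeholder
    "training_config": "<TrainingConfig JSON>",   # training_config.model_dump_json()
    "provider_config": "<provider config JSON>",  # self.config.model_dump_json()
    "enable_nccl_debug": False,
    "nccl_debug_subsys": "NONE",
    "data": "<JSON-encoded dataset rows>",        # json.dumps(data)
}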
(File diff suppressed because it is too large.)
@@ -37,8 +37,6 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer

-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
     def __init__(
         self,
         job_uuid: str,
-        datasetio_api: DatasetIO,
-        datasets_api: Datasets,
+        data: list[dict[str, Any]],
     ):
-        self.datasetio_api = datasetio_api
-        self.datasets_api = datasets_api
+        self.data = data
         self.job_uuid = job_uuid

     def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
             remove_columns=ds.column_names,
         )

-    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
-        """Load dataset from llama stack dataset provider"""
-        try:
-            all_rows = await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
-            if not isinstance(all_rows.data, list):
-                raise RuntimeError("Expected dataset data to be a list")
-            return all_rows.data
-        except Exception as e:
-            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
-
     def _run_training_sync(
         self,
         model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:

         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await self._setup_data(config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
+        if not self.validate_dataset_format(self.data):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
-        logger.info(f"Loaded {len(rows)} rows from dataset")
+        logger.info(f"Loaded {len(self.data)} rows from dataset")

         # Initialize tokenizer
         logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
         # Create and preprocess dataset
         logger.info("Creating and preprocessing dataset")
         try:
-            ds = self._create_dataset(rows, config, provider_config)
+            ds = self._create_dataset(self.data, config, provider_config)
             ds = self._preprocess_dataset(ds, tokenizer, provider_config)
             logger.info(f"Dataset created with {len(ds)} examples")
         except Exception as e:
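A construction sketch (not part of the diff) reflecting the new single-device recipe interface: it now takes preloaded dataset rows instead of dataset API handles. The job UUID and row content are placeholders.

from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
    HFFinetuningSingleDevice,
)

recipe = HFFinetuningSingleDevice(
    job_uuid="job-1234",  # placeholder
    data=[{"input_query": "2+2?", "expected_answer": "4", "chat_completion_input": "[]"}],
)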
@@ -7,10 +7,12 @@
 import abc
 import asyncio
 import functools
+import multiprocessing
 import threading
 from collections.abc import Callable, Coroutine, Iterable
 from datetime import datetime, timezone
 from enum import Enum
+from pathlib import Path
 from typing import Any, TypeAlias

 from pydantic import BaseModel
@@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


 class Job:
-    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
+    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
         super().__init__()
         self.id = job_id
         self._type = job_type
@@ -62,9 +64,38 @@ class Job:
         self._artifacts: list[JobArtifact] = []
         self._logs: list[LogMessage] = []
         self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
+        self._child_processes: list[multiprocessing.Process] = []
+        self._world_size: int = 1  # Number of processes for distributed training
+        self.run_args: dict[str, Any] = {}  # Dictionary to store run arguments

     @property
-    def handler(self) -> JobHandler:
+    def world_size(self) -> int:
+        return self._world_size
+
+    @world_size.setter
+    def world_size(self, size: int) -> None:
+        self._world_size = size
+
+    def add_child_process(self, process: multiprocessing.Process) -> None:
+        self._child_processes.append(process)
+
+    def cancel(self) -> None:
+        """Cancel the job and all its child processes."""
+        for process in self._child_processes:
+            if process.is_alive():
+                process.terminate()
+                process.join(timeout=5)
+        self.status = JobStatus.failed
+
+    def cleanup(self) -> None:
+        """Clean up any remaining child processes."""
+        for process in self._child_processes:
+            if process.is_alive():
+                process.terminate()
+                process.join(timeout=5)
+
+    @property
+    def handler(self) -> JobHandler | None:
         return self._handler

     @property
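A usage sketch (not from the diff) of the new child-process bookkeeping on Job; it assumes JobType and JobID are plain string aliases, as elsewhere in this module, and all values below are placeholders.

import multiprocessing

from llama_stack.providers.utils.scheduler import Job

job = Job("supervised-fine-tune", "job-1234", handler=None)
worker = multiprocessing.Process(target=print, args=("worker",))
worker.start()
job.add_child_process(worker)

job.cancel()   # terminates any live children (5s join) and marks the job failed
job.cleanup()  # same termination pass, without touching the job status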
@@ -111,10 +142,6 @@ class Job:
     def append_log(self, message: LogMessage) -> None:
         self._logs.append(message)

-    # TODO: implement
-    def cancel(self) -> None:
-        raise NotImplementedError
-

 class _SchedulerBackend(abc.ABC):
     @abc.abstractmethod
@@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
     def __init__(self, timeout: int = 5):
         self._timeout = timeout
         self._loop = asyncio.new_event_loop()
-        # There may be performance implications of using threads due to Python
-        # GIL; may need to measure if it's a real problem though
         self._thread = threading.Thread(target=self._run_loop, daemon=True)
         self._thread.start()

@@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         self._loop.run_forever()

         # When stopping the loop, give tasks a chance to finish
-        # TODO: should we explicitly inform jobs of pending stoppage?
         for task in asyncio.all_tasks(self._loop):
             self._loop.run_until_complete(task)
         self._loop.close()
@@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         self._loop.call_soon_threadsafe(self._loop.stop)
         self._thread.join()

-    # TODO: decouple scheduling and running the job
     def schedule(
         self,
         job: Job,
@@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
             try:
                 job.status = JobStatus.running
                 await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
+                job.status = JobStatus.completed
             except Exception as e:
                 on_log_message_cb(str(e))
                 job.status = JobStatus.failed
@@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         pass


+class DistributedJobScheduler(_SchedulerBackend):
+    """A scheduler backend that supports distributed training jobs.
+
+    This scheduler uses torchrun to handle distributed training process spawning and coordination.
+    torchrun automatically handles:
+    - Process spawning
+    - Environment variable setup
+    - Process group initialization
+    - Error handling and process cleanup
+    """
+
+    def __init__(self, timeout: int = 5):
+        self._timeout = timeout
+        self._loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
+
+    def _run_loop(self) -> None:
+        asyncio.set_event_loop(self._loop)
+        self._loop.run_forever()
+
+        # When stopping the loop, give tasks a chance to finish
+        for task in asyncio.all_tasks(self._loop):
+            self._loop.run_until_complete(task)
+        self._loop.close()
+
+    async def shutdown(self) -> None:
+        # Clean up any remaining processes
+        for process in self._active_jobs.values():
+            if process.returncode is None:  # Process is still running
+                process.terminate()
+                try:
+                    await asyncio.wait_for(process.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    process.kill()
+                    await process.wait()
+
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._thread.join()
+
+    def schedule(
+        self,
+        job: Job,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+        on_artifact_collected_cb: Callable[[JobArtifact], None],
+    ) -> None:
+        async def do():
+            try:
+                job.status = JobStatus.running
+
+                # If this is a distributed training job, use torchrun
+                if job.world_size > 1:
+                    # Find the path to finetune_handler.py
+                    from llama_stack.providers.inline.post_training.huggingface import finetune_handler
+
+                    handler_path = Path(finetune_handler.__file__)
+
+                    # Prepare arguments for the handler script
+                    args = [
+                        "torchrun",
+                        f"--nproc_per_node={job.world_size}",
+                        "--master_addr=localhost",
+                        "--master_port=29500",
+                        str(handler_path),
+                    ]
+
+                    # Add arguments from the job.run_args dictionary as proper command-line flags
+                    for arg_name, arg_value in job.run_args.items():
+                        # Skip world_size as we've already handled it
+                        if arg_name == "world_size":
+                            continue
+
+                        if arg_value is not None:
+                            # Handle boolean flags
+                            if isinstance(arg_value, bool):
+                                if arg_value:
+                                    args.append(f"--{arg_name}")
+                            else:
+                                # For non-boolean values, we add the argument as a separate flag and value
+                                args.append(f"--{arg_name}")
+                                args.append(str(arg_value))
+
+                    # Launch torchrun using asyncio
+                    on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
+                    on_log_message_cb(f"Command: {' '.join(args)}")
+
+                    # Make sure we capture stdout and stderr
+                    process = await asyncio.create_subprocess_exec(
+                        *args,
+                        stdout=asyncio.subprocess.PIPE,
+                        stderr=asyncio.subprocess.STDOUT,
+                    )
+
+                    # Store process for this job
+                    self._active_jobs[job.id] = process
+
+                    # Start monitoring in a separate task so we don't block
+                    asyncio.create_task(
+                        self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
+                    )
+                else:
+                    # For single-device training, call the handler directly if provided
+                    if job.handler:
+                        await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
+                        job.status = JobStatus.completed
+                    else:
+                        on_log_message_cb("No handler function provided for single-device training")
+                        job.status = JobStatus.failed
+            except Exception as e:
+                on_log_message_cb(str(e))
+                job.status = JobStatus.failed
+                logger.exception(f"Job {job.id} failed.")
+
+        asyncio.run_coroutine_threadsafe(do(), self._loop)
+
+    async def _monitor_process(
+        self,
+        job: Job,
+        process: asyncio.subprocess.Process,
+        script_path: Path | None,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+    ) -> None:
+        """Monitor a process until completion."""
+        try:
+            # Stream output from the process if stdout is available
+            if process.stdout is not None:
+                while True:
+                    line = await process.stdout.readline()
+                    if not line and process.returncode is not None:
+                        break
+                    if line:
+                        on_log_message_cb(line.decode().strip())
+            else:
+                # If stdout is not available, just wait for the process to complete
+                on_log_message_cb("Process stdout not available, waiting for completion")
+                await process.wait()
+
+            # Wait for process to complete if not already done
+            if process.returncode is None:
+                await process.wait()
+
+            # Check if process failed
+            if process.returncode != 0:
+                on_log_message_cb(f"Training failed with return code {process.returncode}")
+                job.status = JobStatus.failed
+            else:
+                on_status_change_cb(JobStatus.completed)
+                job.status = JobStatus.completed
+        except Exception as e:
+            on_log_message_cb(f"Error monitoring process: {str(e)}")
+            job.status = JobStatus.failed
+            logger.exception(f"Error monitoring process for job {job.id}")
+        finally:
+            # Clean up temporary files
+            if script_path and script_path.exists():
+                script_path.unlink()
+
+            # Remove from active jobs
+            if job.id in self._active_jobs:
+                del self._active_jobs[job.id]
+
+    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
+        pass
+
+    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
+        pass
+
+    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
+        pass
+
+
 _BACKENDS = {
     "naive": _NaiveSchedulerBackend,
+    "distributed": DistributedJobScheduler,
 }
@@ -230,11 +429,18 @@ class Scheduler:
         job.register_artifact(artifact)
         self._backend.on_artifact_collected_cb(job, artifact)

-    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
+    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
         job = Job(type_, job_id, handler)
         if job.id in self._jobs:
             raise ValueError(f"Job {job.id} already exists")

+        # Set world size if provided
+        if "world_size" in run_params:
+            job.world_size = run_params["world_size"]
+
+        # Store all run parameters in the job's run_args dictionary
+        job.run_args = run_params
+
         self._jobs[job.id] = job
         job.status = JobStatus.scheduled
         self._backend.schedule(
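A call-site sketch (not from the diff) of the widened schedule() signature, which now takes run_params; the backend name comes from the _BACKENDS table above, while the job type, ID, and handler are placeholders.

from llama_stack.providers.utils.scheduler import Scheduler

async def noop_handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
    on_log_message_cb("no-op job")

scheduler = Scheduler(backend="naive")
job_id = scheduler.schedule(
    "supervised-fine-tune",  # type_ (placeholder)
    "job-1234",              # job_id (placeholder)
    noop_handler,            # handler; may be None for torchrun-only jobs
    {"world_size": 1},       # run_params (new argument)
)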
@@ -100,6 +100,7 @@ providers:
       checkpoint_format: huggingface
       distributed_backend: null
       device: cpu
+      recipe: single
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
@@ -98,6 +98,7 @@ providers:
       checkpoint_format: huggingface
       distributed_backend: null
       device: cpu
+      recipe: single
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search