Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-16 18:08:09 +00:00)
feat: DistributedJobScheduler
Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which kicks off a `finetune_handler.py` script using torchrun. The handler parses the training arguments via argparse and calls the appropriate recipe, as `post_training.py` used to do. Torchrun takes care of environment variables such as WORLD_SIZE and LOCAL_RANK.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
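For context, the hand-off described above works roughly as sketched below: the provider collects the job parameters, and a torchrun launch of `finetune_handler.py` spawns one handler process per rank, with torchrun injecting WORLD_SIZE, RANK, and LOCAL_RANK into each process's environment. This is a minimal illustrative sketch only; the helper name `launch_with_torchrun` and the exact flag set are assumptions, not the DistributedJobScheduler implementation (which is not shown in this diff).

```python
# Hypothetical sketch of the torchrun launch path; not the actual scheduler code.
import subprocess


def launch_with_torchrun(run_params: dict, world_size: int) -> None:
    """Launch finetune_handler.py under torchrun with one process per GPU.

    torchrun sets WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT
    for every spawned process; the handler itself only reads its argparse flags.
    """
    handler = "llama_stack/providers/inline/post_training/huggingface/finetune_handler.py"
    cmd = [
        "torchrun",
        f"--nproc_per_node={world_size}",
        handler,
        "--job_uuid", run_params["job_uuid"],
        "--model", run_params["model"],
        "--training_config", run_params["training_config"],
        "--provider_config", run_params["provider_config"],
        "--data", run_params["data"],
    ]
    if run_params.get("checkpoint_dir"):
        cmd += ["--checkpoint_dir", run_params["checkpoint_dir"]]
    if run_params.get("enable_nccl_debug"):
        cmd.append("--enable_nccl_debug")
    # Block until every rank exits; a non-zero exit marks the job as failed.
    subprocess.run(cmd, check=True)
```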
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions
llama_stack/providers/inline/post_training/huggingface/finetune_handler.py (Executable file, 174 lines added)
@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import asyncio
import json
import os
from typing import Any

from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
    HFFinetuningMultiDevice,
)
from llama_stack.providers.utils.scheduler import JobStatus

async def train(
    job_uuid,
    model,
    checkpoint_dir,
    training_config,
    provider_config,
    algorithm_config,
    data,
    datasetio_api=None,
    datasets_api=None,
    enable_nccl_debug=False,
    nccl_debug_subsys="NONE",
):
    """Handler function for HuggingFace training that can be called by torchrun.

    This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
    It follows the same flow, but is designed to be called directly from a script.

    Args:
        job_uuid: Unique ID for this job
        model: Model to train
        checkpoint_dir: Directory to save checkpoints to
        training_config: Training configuration
        provider_config: Provider configuration
        algorithm_config: Algorithm configuration
        data: The dataset rows to be processed
        datasetio_api: Dataset I/O API; may be None since the rows arrive pre-loaded via `data`
        datasets_api: Datasets API; may be None for the same reason
        enable_nccl_debug: Whether to enable NCCL debugging
        nccl_debug_subsys: NCCL subsystem to debug
    """
    # Get rank information when running distributed
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    parsed_data: list[dict[str, Any]] = json.loads(data)

    # Set up callback functions with rank information
    def on_log_message_cb(msg):
        print(f"[RANK {local_rank}] {msg}", flush=True)

    def on_status_change_cb(status):
        print(f"[RANK {local_rank}] Status: {status}", flush=True)

    def on_artifact_collected_cb(artifact):
        print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True)

    on_log_message_cb("Starting HF finetuning")

    recipe_obj = HFFinetuningMultiDevice(
        job_uuid=job_uuid,
        # The recipe constructor requires these APIs; they may be None here because
        # the dataset rows are already provided via `data`.
        datasetio_api=datasetio_api,
        datasets_api=datasets_api,
        data=parsed_data,
        enable_nccl_debug=enable_nccl_debug,
        nccl_debug_subsys=nccl_debug_subsys,
    )

    resources_allocated, checkpoints = await recipe_obj.train(
        model=model,
        output_dir=checkpoint_dir,
        job_uuid=job_uuid,
        lora_config=algorithm_config,
        config=training_config,
        provider_config=provider_config,
    )

    def resources_stats_to_artifact(resources_stats):
        return {
            "type": "resources_stats",
            "name": "resources_stats",
            "metadata": resources_stats,
        }

    def checkpoint_to_artifact(checkpoint):
        return {
            "type": "checkpoint",
            "name": checkpoint.identifier,
            "uri": checkpoint.path,
            "metadata": dict(checkpoint),
        }

    on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated))
    if checkpoints:
        for checkpoint in checkpoints:
            artifact = checkpoint_to_artifact(checkpoint)
            on_artifact_collected_cb(artifact)

    on_status_change_cb(JobStatus.completed)
    on_log_message_cb("HF finetuning completed")

async def main():
    parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.")
    parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
    parser.add_argument("--model", type=str, required=True, help="Model to use")
    parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
    parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
    parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
    parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
    parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
    parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
    parser.add_argument("--data", type=str, required=True)

    args = parser.parse_args()

    # Parse JSON configs
    try:
        training_config = TrainingConfig.model_validate_json(args.training_config)
    except Exception as e:
        print(f"Error parsing training_config: {e}")
        print(f"Received: {args.training_config}")
        raise

    try:
        provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
    except Exception as e:
        print(f"Error parsing provider_config: {e}")
        print(f"Received: {args.provider_config}")
        raise

    algorithm_config = None
    if args.algorithm_config:
        try:
            algorithm_config = json.loads(args.algorithm_config)
        except json.JSONDecodeError as e:
            print(f"Error parsing algorithm_config: {e}")
            print(f"Received: {args.algorithm_config}")
            raise

    # In a real implementation, you would get these from somewhere
    # For now, we'll pass None and handle it in the train function
    datasetio_api = None
    datasets_api = None

    # Print arguments for debugging
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if local_rank == 0:  # Only the main process prints
        print("Starting training with arguments:")
        print(f" job_uuid: {args.job_uuid}")
        print(f" model: {args.model}")
        print(f" checkpoint_dir: {args.checkpoint_dir}")
        print(f" enable_nccl_debug: {args.enable_nccl_debug}")
        print(f" nccl_debug_subsys: {args.nccl_debug_subsys}")

    await train(
        job_uuid=args.job_uuid,
        model=args.model,
        checkpoint_dir=args.checkpoint_dir,
        training_config=training_config,
        provider_config=provider_config,
        algorithm_config=algorithm_config,
        datasetio_api=datasetio_api,
        datasets_api=datasets_api,
        enable_nccl_debug=args.enable_nccl_debug,
        nccl_debug_subsys=args.nccl_debug_subsys,
        data=args.data,
    )


if __name__ == "__main__":
    asyncio.run(main())

llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from enum import Enum
from typing import Any

@@ -81,43 +82,130 @@ class HuggingFacePostTrainingImpl:
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
            on_log_message_cb("Starting HF finetuning")
        from collections.abc import Callable, Coroutine
        from typing import Any

            recipe = HFFinetuningSingleDevice(
                job_uuid=job_uuid,
                datasetio_api=self.datasetio_api,
                datasets_api=self.datasets_api,
            )
            if self.config.recipe == "multi":
                recipe = HFFinetuningMultiDevice(
        # Type for the handler: async fn taking 3 Any args, returns Awaitable[None]
        handler: (
            Callable[
                [Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
                Coroutine[Any, Any, None],
            ]
            | None
        ) = None

        # Determine world size for distributed training
        world_size = getattr(self.config, "world_size", 1)

        # Choose the backend and recipe based on world size
        if world_size > 1:
            recipe = "multi"

            # Create parameters for the handler script
            run_params = {
                "job_uuid": job_uuid,
                "model": model,
                "world_size": world_size,
                "recipe": recipe,
            }

            # Add optional parameters
            if checkpoint_dir is not None:
                run_params["checkpoint_dir"] = checkpoint_dir

            if training_config is not None:
                run_params["training_config"] = training_config.model_dump_json()

            if algorithm_config is not None:
                run_params["algorithm_config"] = algorithm_config.model_dump_json()

            # Add provider-specific configuration
            run_params["provider_config"] = self.config.model_dump_json()

            # Add NCCL debug settings if present
            if hasattr(self.config, "enable_nccl_debug"):
                run_params["enable_nccl_debug"] = self.config.enable_nccl_debug

            if hasattr(self.config, "nccl_debug_subsys"):
                run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys

            # Initialize the scheduler with the distributed backend
            self._scheduler = Scheduler(backend="distributed")
        else:
            self._scheduler = Scheduler(backend="naive")

        # TODO: this can probably be cleaner
        # Single-device training path
        # Define a handler for single-device training
        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
            on_log_message_cb("Starting HF finetuning")

            recipe = HFFinetuningSingleDevice(
                job_uuid=job_uuid,
                datasetio_api=self.datasetio_api,
                datasets_api=self.datasets_api,
                enable_nccl_debug=self.config.enable_nccl_debug,
                nccl_debug_subsys=self.config.nccl_debug_subsys,
            )
            if self.config.recipe == "multi":
                recipe = HFFinetuningMultiDevice(
                    job_uuid=job_uuid,
                    datasetio_api=self.datasetio_api,
                    datasets_api=self.datasets_api,
                    enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
                    nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
                )

            resources_allocated, checkpoints = await recipe.train(
                model=model,
                output_dir=checkpoint_dir,
                job_uuid=job_uuid,
                lora_config=algorithm_config,
                config=training_config,
                provider_config=self.config,
            )

            resources_allocated, checkpoints = await recipe.train(
                model=model,
                output_dir=checkpoint_dir,
                job_uuid=job_uuid,
                lora_config=algorithm_config,
                config=training_config,
                provider_config=self.config,
            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
            if checkpoints:
                for checkpoint in checkpoints:
                    artifact = self._checkpoint_to_artifact(checkpoint)
                    on_artifact_collected_cb(artifact)

            on_status_change_cb(SchedulerJobStatus.completed)
            on_log_message_cb("HF finetuning completed")

        assert training_config.data_config is not None
        data = await self._setup_data(dataset_id=training_config.data_config.dataset_id)

        json_data = json.dumps(data)

        run_params["data"] = json_data

        # Schedule the job with the regular scheduler and the handler
        job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)

        return PostTrainingJob(job_uuid=job_id)

    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
        """Load dataset from llama stack dataset provider.

        Args:
            dataset_id: ID of the dataset to load

        Returns:
            list: List of dataset rows

        Raises:
            RuntimeError: If dataset loading fails
        """
        try:
            all_rows = await self.datasetio_api.iterrows(
                dataset_id=dataset_id,
                limit=-1,
            )

            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
            if checkpoints:
                for checkpoint in checkpoints:
                    artifact = self._checkpoint_to_artifact(checkpoint)
                    on_artifact_collected_cb(artifact)

            on_status_change_cb(SchedulerJobStatus.completed)
            on_log_message_cb("HF finetuning completed")

        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
        return PostTrainingJob(job_uuid=job_uuid)
            if not isinstance(all_rows.data, list):
                raise RuntimeError("Expected dataset data to be a list")
            return all_rows.data
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e

    async def preference_optimize(
        self,

llama_stack/providers/inline/post_training/huggingface/recipes/finetune_multi_device.py
@@ -75,8 +75,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    Checkpoint,
    DataConfig,
@@ -191,8 +189,7 @@ class HFFinetuningMultiDevice:
    def __init__(
        self,
        job_uuid: str,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
        data: list[dict[str, Any]],
        enable_nccl_debug: bool = False,
        nccl_debug_subsys: str = "NONE",
    ):
@@ -203,8 +200,7 @@ class HFFinetuningMultiDevice:
            datasetio_api: API for dataset I/O operations
            datasets_api: API for dataset management
        """
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.data = data
        self.job_uuid = job_uuid
        self.enable_nccl_debug = enable_nccl_debug
        self.nccl_debug_subsys = nccl_debug_subsys
@@ -408,29 +404,6 @@ class HFFinetuningMultiDevice:
            num_proc=1,  # Single process to avoid issues
        )

    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
        """Load dataset from llama stack dataset provider.

        Args:
            dataset_id: ID of the dataset to load

        Returns:
            list: List of dataset rows

        Raises:
            RuntimeError: If dataset loading fails
        """
        try:
            all_rows = await self.datasetio_api.iterrows(
                dataset_id=dataset_id,
                limit=-1,
            )
            if not isinstance(all_rows.data, list):
                raise RuntimeError("Expected dataset data to be a list")
            return all_rows.data
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e

    def _run_training_sync(
        self,
        local_rank: int,  # First parameter must be local_rank for spawn
@@ -627,10 +600,9 @@ class HFFinetuningMultiDevice:

        # Load dataset
        logger.info(f"Loading dataset: {config.data_config.dataset_id}")
        rows = await self._setup_data(config.data_config.dataset_id)
        if not self.validate_dataset_format(rows):
        if not self.validate_dataset_format(self.data):
            raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
        logger.info(f"Loaded {len(rows)} rows from dataset")
        logger.info(f"Loaded {len(self.data)} rows from dataset")

        # Initialize tokenizer
        logger.info(f"Initializing tokenizer for model: {model}")
@@ -662,7 +634,7 @@ class HFFinetuningMultiDevice:
        # Create and preprocess dataset
        logger.info("Creating and preprocessing dataset")
        try:
            ds = self._create_dataset(rows, config, provider_config)
            ds = self._create_dataset(self.data, config, provider_config)
            ds = self._preprocess_dataset(ds, tokenizer, provider_config)
            logger.info(f"Dataset created with {len(ds)} examples")
        except Exception as e:
@@ -1021,37 +993,16 @@ class HFFinetuningMultiDevice:
        config: TrainingConfig,
        provider_config: HuggingFacePostTrainingConfig,
    ) -> tuple[dict[str, Any], list[Checkpoint] | None]:
        """Train a model using HuggingFace's SFTTrainer with distributed training.

        The distributed training setup works as follows:
        1. Parse the device list to determine number of GPUs
        2. Use torch.multiprocessing.spawn to launch one process per GPU
        3. Each process runs _run_training_sync with a unique rank
        4. The processes coordinate through NCCL backend
        5. FSDP handles model sharding across GPUs
        6. Only rank 0 handles saving checkpoints and logging

        Args:
            model: The model identifier to load
            output_dir: Optional directory to save checkpoints
            job_uuid: Unique identifier for this training job
            lora_config: LoRA configuration for parameter-efficient fine-tuning
            config: General training configuration
            provider_config: Provider-specific configuration
        Returns:
            tuple: (memory_stats, checkpoints)
        """

        """Train a model using HuggingFace's SFTTrainer with distributed training."""
        if provider_config.distributed_backend != "fsdp":
            raise RuntimeError("Must enable FSDP as distributed backend to use this recipe")

        # Configure NCCL logging based on debug settings
        configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys)

        # Parse device list to determine number of GPUs
        devices = [d.strip() for d in provider_config.device.split(",")]
        world_size = len(devices)
        logger.info(f"Using {world_size} devices: {devices}")
        # Get local rank and world size from environment variables
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        world_size = int(os.environ.get("WORLD_SIZE", "1"))

        output_dir_path = None
        if output_dir:
@@ -1081,32 +1032,22 @@ class HFFinetuningMultiDevice:
            raise ValueError("DataConfig is required for training")

        try:
            # Launch distributed training processes
            # torch.multiprocessing.spawn will:
            # 1. Create world_size number of processes
            # 2. Call _run_training_sync for each process
            # 3. Pass unique local_rank to each process
            # 4. Handle process coordination and cleanup
            logger.info("Starting distributed training processes")
            torch.multiprocessing.spawn(
                self._run_training_sync,
                args=(
                    world_size,
                    model,
                    provider_config.model_dump(),
                    peft_config,
                    config.model_dump(),
                    output_dir_path,
                ),
                nprocs=world_size,
                join=True,  # Wait for all processes to complete
            # Run training for this process
            await self._run_training(
                model=model,
                provider_config=provider_config.model_dump(),
                peft_config=peft_config,
                config=config.model_dump(),
                output_dir_path=output_dir_path,
                local_rank=local_rank,
                world_size=world_size,
            )

            memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0"))

            # Create checkpoint on rank 0
            # Only create checkpoint on rank 0
            checkpoints = None
            if output_dir_path:
            if output_dir_path and local_rank == 0:
                checkpoint = Checkpoint(
                    identifier=f"{model}-sft-{config.n_epochs}",
                    created_at=datetime.now(timezone.utc),

llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -37,8 +37,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    Checkpoint,
    DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
    def __init__(
        self,
        job_uuid: str,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
        data: list[dict[str, Any]],
    ):
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.data = data
        self.job_uuid = job_uuid

    def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
            remove_columns=ds.column_names,
        )

    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
        """Load dataset from llama stack dataset provider"""
        try:
            all_rows = await self.datasetio_api.iterrows(
                dataset_id=dataset_id,
                limit=-1,
            )
            if not isinstance(all_rows.data, list):
                raise RuntimeError("Expected dataset data to be a list")
            return all_rows.data
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e

    def _run_training_sync(
        self,
        model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:

        # Load dataset
        logger.info(f"Loading dataset: {config.data_config.dataset_id}")
        rows = await self._setup_data(config.data_config.dataset_id)
        if not self.validate_dataset_format(rows):
        if not self.validate_dataset_format(self.data):
            raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
        logger.info(f"Loaded {len(rows)} rows from dataset")
        logger.info(f"Loaded {len(self.data)} rows from dataset")

        # Initialize tokenizer
        logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
        # Create and preprocess dataset
        logger.info("Creating and preprocessing dataset")
        try:
            ds = self._create_dataset(rows, config, provider_config)
            ds = self._create_dataset(self.data, config, provider_config)
            ds = self._preprocess_dataset(ds, tokenizer, provider_config)
            logger.info(f"Dataset created with {len(ds)} examples")
        except Exception as e: