diff --git a/llama_stack/providers/inline/post_training/huggingface/__init__.py b/llama_stack/providers/inline/post_training/huggingface/__init__.py
new file mode 100644
index 000000000..cc1a671c1
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.distribution.datatypes import Api
+
+from .config import HuggingFacePostTrainingConfig
+
+# The post_training API and the HuggingFace provider are still experimental and under heavy development
+
+
+async def get_provider_impl(
+    config: HuggingFacePostTrainingConfig,
+    deps: dict[Api, Any],
+):
+    from .post_training import HuggingFacePostTrainingImpl
+
+    impl = HuggingFacePostTrainingImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+    )
+    return impl
diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py
new file mode 100644
index 000000000..06c6d8073
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Literal
+
+from pydantic import BaseModel
+
+
+class HuggingFacePostTrainingConfig(BaseModel):
+    # Device to run training on (cuda, cpu, mps)
+    device: str = "cuda"
+
+    # Distributed training backend if using multiple devices
+    # fsdp: Fully Sharded Data Parallel
+    # deepspeed: DeepSpeed ZeRO optimization
+    distributed_backend: Literal["fsdp", "deepspeed"] | None = None
+
+    # Format for saving model checkpoints
+    # full_state: Save complete model state
+    # huggingface: Save in HuggingFace format (recommended for compatibility)
+    checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
+
+    # Template for formatting chat inputs and outputs
+    # Used to structure the conversation format for training
+    chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
+
+    # Model-specific configuration parameters
+    # trust_remote_code: Allow execution of custom model code
+    # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
+    model_specific_config: dict = {
+        "trust_remote_code": True,
+        "attn_implementation": "sdpa",
+    }
+
+    # Maximum sequence length for training
+    # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
+    # Longer sequences may cause memory issues on MPS devices
+    max_seq_length: int = 2048
+
+    # Enable gradient checkpointing to reduce memory usage
+    # Trades computation for memory by recomputing activations
+    gradient_checkpointing: bool = False
+
+    # Maximum number of checkpoints to keep
+    # Older checkpoints are deleted when this limit is reached
+    save_total_limit: int = 3
+
+    # Number of training steps between logging updates
+    logging_steps: int = 10
+
+    # Ratio of training steps used for learning rate warmup
+    # Helps stabilize early training
+    warmup_ratio: float = 0.1
+
+    # L2 regularization coefficient
+    # Helps prevent overfitting
+    weight_decay: float = 0.01
+
+    # Number of worker processes for data loading
+    # Higher values can improve data loading speed but increase memory usage
+    dataloader_num_workers: int = 4
+
+    # Whether to pin memory in data loader
+    # Can improve data transfer speed to GPU but uses more memory
+    dataloader_pin_memory: bool = True
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
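(Editor's note, not part of the patch: HuggingFacePostTrainingConfig is a plain Pydantic model, so the dict returned by sample_run_config() maps directly onto its keyword arguments. A minimal sketch of consuming it, assuming only what config.py above defines; the "Hello"/"Hi there!" strings are made up.)

from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

# Mirror the sample run config: CPU device, HuggingFace-format checkpoints, no distributed backend.
config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",
    distributed_backend=None,
    device="cpu",
)

# Every other field keeps the default declared on the class.
assert config.max_seq_length == 2048
assert config.warmup_ratio == 0.1
print(config.chat_template.format(input="Hello", output="Hi there!"))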
diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py
new file mode 100644
index 000000000..0b2760792
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from enum import Enum
+from typing import Any
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    Checkpoint,
+    DPOAlignmentConfig,
+    JobStatus,
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
+from llama_stack.providers.inline.post_training.huggingface.config import (
+    HuggingFacePostTrainingConfig,
+)
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
+    HFFinetuningSingleDevice,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack.schema_utils import webmethod
+
+
+class TrainingArtifactType(Enum):
+    CHECKPOINT = "checkpoint"
+    RESOURCES_STATS = "resources_stats"
+
+
+_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
+
+
+class HuggingFacePostTrainingImpl:
+    def __init__(
+        self,
+        config: HuggingFacePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets
+        self._scheduler = Scheduler()
+
+    async def shutdown(self) -> None:
+        await self._scheduler.shutdown()
+
+    @staticmethod
+    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.CHECKPOINT.value,
+            name=checkpoint.identifier,
+            uri=checkpoint.path,
+            metadata=dict(checkpoint),
+        )
+
+    @staticmethod
+    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.RESOURCES_STATS.value,
+            name=TrainingArtifactType.RESOURCES_STATS.value,
+            metadata=resources_stats,
+        )
+
+    async def supervised_fine_tune(
+        self,
+        job_uuid: str,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
+    ) -> PostTrainingJob:
+        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+            on_log_message_cb("Starting HF finetuning")
+
+            recipe = HFFinetuningSingleDevice(
+                job_uuid=job_uuid,
+                datasetio_api=self.datasetio_api,
+                datasets_api=self.datasets_api,
+            )
+
+            resources_allocated, checkpoints = await recipe.train(
+                model=model,
+                output_dir=checkpoint_dir,
+                job_uuid=job_uuid,
+                lora_config=algorithm_config,
+                config=training_config,
+                provider_config=self.config,
+            )
+
+            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+            if checkpoints:
+                for checkpoint in checkpoints:
+                    artifact = self._checkpoint_to_artifact(checkpoint)
+                    on_artifact_collected_cb(artifact)
+
+            on_status_change_cb(SchedulerJobStatus.completed)
+            on_log_message_cb("HF finetuning completed")
+
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        return PostTrainingJob(job_uuid=job_uuid)
+
+    async def preference_optimize(
+        self,
+        job_uuid: str,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+    ) -> PostTrainingJob:
+        raise NotImplementedError("DPO alignment is not implemented yet")
+
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        return ListPostTrainingJobsResponse(
+            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
+        )
+
+    @staticmethod
+    def _get_artifacts_metadata_by_type(job, artifact_type):
+        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
+
+    @classmethod
+    def _get_checkpoints(cls, job):
+        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
+
+    @classmethod
+    def _get_resources_allocated(cls, job):
+        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
+        return data[0] if data else None
+
+    @webmethod(route="/post-training/job/status")
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(job_uuid)
+
+        match job.status:
+            # TODO: Add support for other statuses to API
+            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
+                status = JobStatus.scheduled
+            case SchedulerJobStatus.running:
+                status = JobStatus.in_progress
+            case SchedulerJobStatus.completed:
+                status = JobStatus.completed
+            case SchedulerJobStatus.failed:
+                status = JobStatus.failed
+            case _:
+                raise NotImplementedError()
+
+        return PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=status,
+            scheduled_at=job.scheduled_at,
+            started_at=job.started_at,
+            completed_at=job.completed_at,
+            checkpoints=self._get_checkpoints(job),
+            resources_allocated=self._get_resources_allocated(job),
+        )
+
+    @webmethod(route="/post-training/job/cancel")
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        self._scheduler.cancel(job_uuid)
+
+    @webmethod(route="/post-training/job/artifacts")
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
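(Editor's note, not part of the patch: a minimal sketch of wiring the implementation above by hand, mirroring what get_provider_impl() in the package __init__.py does for a distribution. The datasetio_impl and datasets_impl arguments are hypothetical placeholders for whatever providers the stack resolves for Api.datasetio and Api.datasets.)

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.huggingface import (
    HuggingFacePostTrainingConfig,
    get_provider_impl,
)


async def build_post_training_impl(datasetio_impl, datasets_impl):
    # get_provider_impl() constructs HuggingFacePostTrainingImpl with the two API dependencies.
    config = HuggingFacePostTrainingConfig(device="cpu", checkpoint_format="huggingface")
    deps = {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl}
    return await get_provider_impl(config, deps)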
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
new file mode 100644
index 000000000..fd1c68655
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -0,0 +1,502 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import gc
+import json
+import logging
+import os
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+import psutil
+
+# Set tokenizer parallelism environment variable
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import torch
+from datasets import Dataset
+from peft import LoraConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
+from trl import SFTConfig, SFTTrainer
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    Checkpoint,
+    DataConfig,
+    LoraFinetuningConfig,
+    TrainingConfig,
+)
+
+from ..config import HuggingFacePostTrainingConfig
+
+logger = logging.getLogger(__name__)
+
+
+def get_gb(to_convert: int) -> str:
+    """Converts memory stats to GB and formats to 2 decimal places.
+    Args:
+        to_convert: Memory value in bytes
+    Returns:
+        str: Memory value in GB formatted to 2 decimal places
+    """
+    return f"{(to_convert / (1024**3)):.2f}"
+
+
+def get_memory_stats(device: torch.device) -> dict[str, Any]:
+    """Get memory statistics for the given device."""
+    stats = {
+        "system_memory": {
+            "total": get_gb(psutil.virtual_memory().total),
+            "available": get_gb(psutil.virtual_memory().available),
+            "used": get_gb(psutil.virtual_memory().used),
+            "percent": psutil.virtual_memory().percent,
+        }
+    }
+
+    if device.type == "cuda":
+        stats["device_memory"] = {
+            "allocated": get_gb(torch.cuda.memory_allocated(device)),
+            "reserved": get_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+        }
+    elif device.type == "mps":
+        # MPS doesn't provide direct memory stats, but we can track system memory
+        stats["device_memory"] = {
+            "note": "MPS memory stats not directly available",
+            "system_memory_used": get_gb(psutil.virtual_memory().used),
+        }
+    elif device.type == "cpu":
+        # For CPU, we track process memory usage
+        process = psutil.Process()
+        stats["device_memory"] = {
+            "process_rss": get_gb(process.memory_info().rss),
+            "process_vms": get_gb(process.memory_info().vms),
+            "process_percent": process.memory_percent(),
+        }
+
+    return stats
+
+
+class HFFinetuningSingleDevice:
+    def __init__(
+        self,
+        job_uuid,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+    ):
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.job_uuid = job_uuid
+
+    def validate_dataset_format(self, rows: list[dict]) -> bool:
+        """Validate that the dataset has the required fields."""
+        required_fields = ["input_query", "expected_answer", "chat_completion_input"]
+        return all(field in row for row in rows for field in required_fields)
+
+    def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
+        """Process a row in instruct format."""
+        if "chat_completion_input" in row and "expected_answer" in row:
+            try:
+                messages = json.loads(row["chat_completion_input"])
+                if not isinstance(messages, list) or len(messages) != 1:
+                    logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
+                    return None, None
+                if "content" not in messages[0]:
+                    logger.warning(f"Message missing content: {messages[0]}")
+                    return None, None
+                return messages[0]["content"], row["expected_answer"]
+            except json.JSONDecodeError:
+                logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
+                return None, None
+        return None, None
+
+    def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
+        """Process a row in dialog format."""
+        if "dialog" in row:
+            try:
+                dialog = json.loads(row["dialog"])
+                if not isinstance(dialog, list) or len(dialog) < 2:
+                    logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
+                    return None, None
+                if dialog[0].get("role") != "user":
+                    logger.warning(f"First message must be from user: {dialog[0]}")
+                    return None, None
+                if not any(msg.get("role") == "assistant" for msg in dialog):
+                    logger.warning("Dialog must have at least one assistant message")
+                    return None, None
+
+                # Convert to human/gpt format
+                role_map = {"user": "human", "assistant": "gpt"}
+                conversations = []
+                for msg in dialog:
+                    if "role" not in msg or "content" not in msg:
+                        logger.warning(f"Message missing role or content: {msg}")
+                        continue
+                    conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
+
+                # Format as a single conversation
+                return conversations[0]["value"], conversations[1]["value"]
+            except json.JSONDecodeError:
+                logger.warning(f"Failed to parse dialog: {row['dialog']}")
+                return None, None
+        return None, None
+
+    def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
+        """Process a row using fallback formats."""
+        if "input" in row and "output" in row:
+            return row["input"], row["output"]
+        elif "prompt" in row and "completion" in row:
+            return row["prompt"], row["completion"]
+        elif "question" in row and "answer" in row:
+            return row["question"], row["answer"]
+        return None, None
+
+    def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
+        """Format input and output text based on model requirements."""
+        if hasattr(provider_config, "chat_template"):
+            return provider_config.chat_template.format(input=input_text, output=output_text)
+        return f"{input_text}\n{output_text}"
+
+    def _create_dataset(
+        self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
+    ) -> Dataset:
+        """Create and preprocess the dataset."""
+        formatted_rows = []
+        for row in rows:
+            input_text = None
+            output_text = None
+
+            # Process based on format
+            assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
+            if config.data_config.data_format.value == "instruct":
+                input_text, output_text = self._process_instruct_format(row)
+            elif config.data_config.data_format.value == "dialog":
+                input_text, output_text = self._process_dialog_format(row)
+            else:
+                input_text, output_text = self._process_fallback_format(row)
+
+            if input_text and output_text:
+                formatted_text = self._format_text(input_text, output_text, provider_config)
+                formatted_rows.append({"text": formatted_text})
+
+        if not formatted_rows:
+            assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
+            raise ValueError(
+                f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
+            )
+
+        return Dataset.from_list(formatted_rows)
+
+    def _preprocess_dataset(
+        self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
+    ) -> Dataset:
+        """Preprocess the dataset with tokenizer."""
+
+        def tokenize_function(examples):
+            return tokenizer(
+                examples["text"],
+                padding=True,
+                truncation=True,
+                max_length=provider_config.max_seq_length,
+                return_tensors=None,
+            )
+
+        return ds.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=ds.column_names,
+        )
+
+    async def load_dataset(
+        self,
+        model: str,
+        config: TrainingConfig,
+        provider_config: HuggingFacePostTrainingConfig,
+    ) -> tuple[Dataset, Dataset, AutoTokenizer]:
+        """Load and preprocess the dataset for training.
+
+        Args:
+            model: The model identifier to load
+            config: Training configuration containing dataset settings
+            provider_config: Provider-specific configuration
+
+        Returns:
+            tuple containing:
+            - Training dataset
+            - Evaluation dataset
+            - Tokenizer
+
+        Raises:
+            ValueError: If dataset is missing required fields
+            RuntimeError: If tokenizer initialization fails
+        """
+        assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
+        rows = await self._setup_data(config.data_config.dataset_id)
+
+        # Validate that the dataset has the required fields for training
+        if not self.validate_dataset_format(rows):
+            raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
+
+        # Initialize tokenizer with model-specific config
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
+            # Set up tokenizer defaults
+            if not tokenizer.pad_token:
+                tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.padding_side = "right"
+            tokenizer.truncation_side = "right"
+            tokenizer.model_max_length = provider_config.max_seq_length
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
+
+        # Create and preprocess dataset
+        try:
+            ds = self._create_dataset(rows, config, provider_config)
+            ds = self._preprocess_dataset(ds, tokenizer, provider_config)
+        except Exception as e:
+            raise ValueError(f"Failed to create dataset: {str(e)}") from e
+
+        # Split dataset into train and validation
+        train_val_split = ds.train_test_split(test_size=0.1, seed=42)
+        return train_val_split["train"], train_val_split["test"], tokenizer
+
+    def load_model(
+        self,
+        model: str,
+        device: torch.device,
+        provider_config: HuggingFacePostTrainingConfig,
+    ) -> AutoModelForCausalLM:
+        """Load and initialize the model for training.
+
+        Args:
+            model: The model identifier to load
+            device: The device to load the model onto
+            provider_config: Provider-specific configuration
+
+        Returns:
+            The loaded and initialized model
+
+        Raises:
+            RuntimeError: If model loading fails
+        """
+        logger.info("Loading the base model")
+        try:
+            model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
+            model_obj = AutoModelForCausalLM.from_pretrained(
+                model,
+                torch_dtype="auto",
+                quantization_config=None,
+                config=model_config,
+                **provider_config.model_specific_config,
+            )
+            if model_obj.device != device:
+                model_obj = model_obj.to(device)
+            logger.info(f"Model loaded and moved to device: {model_obj.device}")
+            return model_obj
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model: {str(e)}") from e
+
+    async def train(
+        self,
+        model: str,
+        output_dir: str | None,
+        job_uuid: str,
+        lora_config: LoraFinetuningConfig,
+        config: TrainingConfig,
+        provider_config: HuggingFacePostTrainingConfig,
+    ) -> tuple[dict[str, Any], list[Checkpoint] | None]:
+        """Train a model using HuggingFace's SFTTrainer"""
+        try:
+            device = torch.device(provider_config.device)
+        except RuntimeError as e:
+            raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
+
+        # Detect device type and validate
+        if device.type == "cuda":
+            if not torch.cuda.is_available():
+                raise RuntimeError(
+                    f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+                )
+            # map unqualified 'cuda' to current device
+            if device.index is None:
+                device = torch.device(device.type, torch.cuda.current_device())
+        elif device.type == "mps":
+            if not torch.backends.mps.is_available():
+                raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+        elif device.type == "hpu":
+            raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+        logger.info(f"Using device '{device}'")
+        output_dir_path = None
+        if output_dir:
+            output_dir_path = Path(output_dir)
+
+        # Track memory stats throughout training
+        memory_stats = {
+            "initial": get_memory_stats(device),
+            "after_model_load": None,
+            "after_training": None,
+            "final": None,
+        }
+
+        # Validate data config
+        if not config.data_config:
+            raise ValueError("DataConfig is required for training")
+
+        # Load dataset and tokenizer
+        train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config, provider_config)
+
+        # Load model with model-specific config
+        model_obj = self.load_model(model, device, provider_config)
+        memory_stats["after_model_load"] = get_memory_stats(device)
+
+        # Configure LoRA
+        peft_config = None
+        if lora_config:
+            peft_config = LoraConfig(
+                lora_alpha=lora_config.alpha,
+                lora_dropout=0.1,
+                r=lora_config.rank,
+                bias="none",
+                task_type="CAUSAL_LM",
+                target_modules=lora_config.lora_attn_modules,
+            )
+
+        # Setup training arguments
+        lr = 2e-5
+        if config.optimizer_config:
+            lr = config.optimizer_config.lr
+
+        # Calculate steps per epoch and appropriate intervals
+        steps_per_epoch = len(train_dataset) // config.data_config.batch_size
+        eval_steps = max(1, steps_per_epoch // 10)  # Evaluate 10 times per epoch
+        save_steps = max(1, steps_per_epoch // 5)  # Save 5 times per epoch
+        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
+
+        logger.info(f"Dataset size: {len(train_dataset)} examples")
+        logger.info(f"Batch size: {config.data_config.batch_size}")
+        logger.info(f"Steps per epoch: {steps_per_epoch}")
+        logger.info(f"Will evaluate every {eval_steps} steps")
+        logger.info(f"Will save every {save_steps} steps")
+        logger.info(f"Will log every {logging_steps} steps")
+
+        # save_strategy should be none if output dir is none
+        save_strategy = "no"
+        if output_dir_path:
+            save_strategy = "steps"
+        training_arguments = SFTConfig(
+            max_steps=config.max_steps_per_epoch,
+            output_dir=str(output_dir_path) if output_dir_path is not None else None,
+            num_train_epochs=config.n_epochs,
+            per_device_train_batch_size=config.data_config.batch_size,
+            fp16=device.type == "cuda",
+            bf16=device.type != "cuda",
+            # use_cpu should only be set if we are on a "True" CPU machine, not an MPS-enabled Mac, due to stability issues.
+            use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
+            save_strategy=save_strategy,
+            save_steps=save_steps,
+            report_to="none",
+            max_seq_length=provider_config.max_seq_length,
+            gradient_accumulation_steps=config.gradient_accumulation_steps,
+            gradient_checkpointing=provider_config.gradient_checkpointing,
+            learning_rate=lr,
+            warmup_ratio=provider_config.warmup_ratio,
+            weight_decay=provider_config.weight_decay,
+            logging_steps=logging_steps,
+            # Enable validation
+            eval_strategy="steps",
+            eval_steps=eval_steps,
+            save_total_limit=provider_config.save_total_limit,
+            remove_unused_columns=False,
+            dataloader_pin_memory=provider_config.dataloader_pin_memory,
+            dataloader_num_workers=provider_config.dataloader_num_workers,
+            dataset_text_field="text",
+            packing=False,
+            # Add evaluation metrics
+            # loading the best model can only happen if we have saved a model
+            load_best_model_at_end=True if output_dir_path else False,
+            metric_for_best_model="eval_loss",
+            greater_is_better=False,
+        )
+
+        # Initialize trainer with both train and eval datasets
+        trainer = SFTTrainer(
+            model=model_obj,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            peft_config=peft_config,
+            args=training_arguments,
+        )
+
+        # Train
+        logger.info("Starting training")
+        try:
+            trainer.train()
+            memory_stats["after_training"] = get_memory_stats(device)
+
+            # Save final model
+            model_obj.config.use_cache = True
+            # if we have LoRA we need to do `merge_and_unload`
+            if lora_config:
+                model_obj = trainer.model.merge_and_unload()
+            else:
+                model_obj = trainer.model
+
+            checkpoint = None
+            checkpoints = None
+            # only save a final model if checkpoint dir is specified
+            # this is especially useful to test training rather than saving of checkpoints
+            if output_dir_path:
+                model_obj.save_pretrained(output_dir_path / "merged_model")
+
+                # Create checkpoint
+                checkpoint = Checkpoint(
+                    identifier=f"{model}-sft-{config.n_epochs}",
+                    created_at=datetime.now(timezone.utc),
+                    epoch=config.n_epochs,
+                    post_training_job_id=job_uuid,
+                    path=str(output_dir_path / "merged_model"),
+                )
+                checkpoints = [checkpoint]
+
+            return memory_stats, checkpoints
+        finally:
+            # Clean up resources
+            if hasattr(trainer, "model"):
+                if device.type != "cpu":
+                    trainer.model.to("cpu")
+                    if device.type == "cuda":
+                        torch.cuda.empty_cache()
+                del trainer.model
+            del trainer
+            gc.collect()
+            memory_stats["final"] = get_memory_stats(device)
+
+    async def _setup_data(
+        self,
+        dataset_id: str,
+    ) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider"""
+        try:
+
+            async def fetch_rows(dataset_id: str):
+                return await self.datasetio_api.iterrows(
+                    dataset_id=dataset_id,
+                    limit=-1,
+                )
+
+            all_rows = await fetch_rows(dataset_id)
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
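(Editor's note, not part of the patch: a self-contained sketch of the "instruct" row shape that validate_dataset_format() and _process_instruct_format() above accept, and how a pair is flattened with the default chat_template before tokenization. The question and answer strings are made up.)

import json

# One "instruct"-format row: chat_completion_input is a JSON-encoded list holding exactly one
# message with a "content" key, expected_answer holds the target completion, and input_query
# must also be present for validate_dataset_format() to pass.
row = {
    "input_query": "What is the capital of France?",
    "chat_completion_input": json.dumps([{"role": "user", "content": "What is the capital of France?"}]),
    "expected_answer": "Paris is the capital of France.",
}

messages = json.loads(row["chat_completion_input"])
input_text, output_text = messages[0]["content"], row["expected_answer"]

# _format_text() then renders the pair with the provider's chat_template into the "text" column.
chat_template = "<|user|>\n{input}\n<|assistant|>\n{output}"
print(chat_template.format(input=input_text, output=output_text))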
module="llama_stack.providers.inline.post_training.huggingface", + config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), remote_provider_spec( api=Api.post_training, adapter=AdapterSpec(