diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d755ff0ae..c083da7d9 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -58,7 +58,7 @@ jobs: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | source .venv/bin/activate - nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 & + LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv & - name: Wait for Llama Stack server to be ready if: matrix.client-type == 'http' @@ -85,6 +85,11 @@ jobs: echo "Ollama health check failed" exit 1 fi + - name: Check Storage and Memory Available Before Tests + if: ${{ always() }} + run: | + free -h + df -h - name: Run Integration Tests env: @@ -100,12 +105,19 @@ jobs: --text-model="meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 + - name: Check Storage and Memory Available After Tests + if: ${{ always() }} + run: | + free -h + df -h + - name: Write ollama logs to file + if: ${{ always() }} run: | sudo journalctl -u ollama.service > ollama.log - name: Upload all logs to artifacts - if: always() + if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }} diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index a3d67d4ce..4d148feda 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | | inference | `remote::ollama` | +| post_training | `inline::huggingface` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/llama_stack/providers/inline/post_training/common/utils.py b/llama_stack/providers/inline/post_training/common/utils.py new file mode 100644 index 000000000..7840b21e8 --- /dev/null +++ b/llama_stack/providers/inline/post_training/common/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import gc + + +def evacuate_model_from_device(model, device: str): + """Safely clear a model from memory and free device resources. + This function handles the proper cleanup of a model by: + 1. Moving the model to CPU if it's on a non-CPU device + 2. Deleting the model object to free memory + 3. Running garbage collection + 4. 
Clearing CUDA cache if the model was on a CUDA device + Args: + model: The PyTorch model to clear + device: The device type the model is currently on ('cuda', 'mps', 'cpu') + Note: + - For CUDA devices, this will clear the CUDA cache after moving the model to CPU + - For MPS devices, only moves the model to CPU (no cache clearing available) + - For CPU devices, only deletes the model object and runs garbage collection + """ + if device != "cpu": + model.to("cpu") + + del model + gc.collect() + + if device == "cuda": + # we need to import such that this is only imported when the method is called + import torch + + torch.cuda.empty_cache() diff --git a/llama_stack/providers/inline/post_training/huggingface/__init__.py b/llama_stack/providers/inline/post_training/huggingface/__init__.py new file mode 100644 index 000000000..cc1a671c1 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.distribution.datatypes import Api + +from .config import HuggingFacePostTrainingConfig + +# post_training api and the huggingface provider is still experimental and under heavy development + + +async def get_provider_impl( + config: HuggingFacePostTrainingConfig, + deps: dict[Api, Any], +): + from .post_training import HuggingFacePostTrainingImpl + + impl = HuggingFacePostTrainingImpl( + config, + deps[Api.datasetio], + deps[Api.datasets], + ) + return impl diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py new file mode 100644 index 000000000..06c6d8073 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Literal + +from pydantic import BaseModel + + +class HuggingFacePostTrainingConfig(BaseModel): + # Device to run training on (cuda, cpu, mps) + device: str = "cuda" + + # Distributed training backend if using multiple devices + # fsdp: Fully Sharded Data Parallel + # deepspeed: DeepSpeed ZeRO optimization + distributed_backend: Literal["fsdp", "deepspeed"] | None = None + + # Format for saving model checkpoints + # full_state: Save complete model state + # huggingface: Save in HuggingFace format (recommended for compatibility) + checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface" + + # Template for formatting chat inputs and outputs + # Used to structure the conversation format for training + chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}" + + # Model-specific configuration parameters + # trust_remote_code: Allow execution of custom model code + # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance + model_specific_config: dict = { + "trust_remote_code": True, + "attn_implementation": "sdpa", + } + + # Maximum sequence length for training + # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon) + # Longer sequences may cause memory issues on MPS devices + max_seq_length: int = 2048 + + # Enable gradient checkpointing to reduce memory usage + # Trades computation for memory by recomputing activations + gradient_checkpointing: bool = False + + # Maximum number of checkpoints to keep + # Older checkpoints are deleted when this limit is reached + save_total_limit: int = 3 + + # Number of training steps between logging updates + logging_steps: int = 10 + + # Ratio of training steps used for learning rate warmup + # Helps stabilize early training + warmup_ratio: float = 0.1 + + # L2 regularization coefficient + # Helps prevent overfitting + weight_decay: float = 0.01 + + # Number of worker processes for data loading + # Higher values can improve data loading speed but increase memory usage + dataloader_num_workers: int = 4 + + # Whether to pin memory in data loader + # Can improve data transfer speed to GPU but uses more memory + dataloader_pin_memory: bool = True + + @classmethod + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: + return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"} diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py new file mode 100644 index 000000000..0b2760792 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from enum import Enum +from typing import Any + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + Checkpoint, + DPOAlignmentConfig, + JobStatus, + ListPostTrainingJobsResponse, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) +from llama_stack.providers.inline.post_training.huggingface.config import ( + HuggingFacePostTrainingConfig, +) +from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( + HFFinetuningSingleDevice, +) +from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler +from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus +from llama_stack.schema_utils import webmethod + + +class TrainingArtifactType(Enum): + CHECKPOINT = "checkpoint" + RESOURCES_STATS = "resources_stats" + + +_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune" + + +class HuggingFacePostTrainingImpl: + def __init__( + self, + config: HuggingFacePostTrainingConfig, + datasetio_api: DatasetIO, + datasets: Datasets, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets + self._scheduler = Scheduler() + + async def shutdown(self) -> None: + await self._scheduler.shutdown() + + @staticmethod + def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.CHECKPOINT.value, + name=checkpoint.identifier, + uri=checkpoint.path, + metadata=dict(checkpoint), + ) + + @staticmethod + def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.RESOURCES_STATS.value, + name=TrainingArtifactType.RESOURCES_STATS.value, + metadata=resources_stats, + ) + + async def supervised_fine_tune( + self, + job_uuid: str, + training_config: TrainingConfig, + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], + model: str, + checkpoint_dir: str | None = None, + algorithm_config: AlgorithmConfig | None = None, + ) -> PostTrainingJob: + async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + on_log_message_cb("Starting HF finetuning") + + recipe = HFFinetuningSingleDevice( + job_uuid=job_uuid, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + ) + + resources_allocated, checkpoints = await recipe.train( + model=model, + output_dir=checkpoint_dir, + job_uuid=job_uuid, + lora_config=algorithm_config, + config=training_config, + provider_config=self.config, + ) + + on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = self._checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + + on_status_change_cb(SchedulerJobStatus.completed) + on_log_message_cb("HF finetuning completed") + + job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) + return PostTrainingJob(job_uuid=job_uuid) + + async def preference_optimize( + self, + job_uuid: str, + finetuned_model: str, + algorithm_config: DPOAlignmentConfig, + training_config: TrainingConfig, + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], + ) -> PostTrainingJob: + raise NotImplementedError("DPO alignment is not implemented yet") + + async def get_training_jobs(self) -> ListPostTrainingJobsResponse: + return 
ListPostTrainingJobsResponse( + data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] + ) + + @staticmethod + def _get_artifacts_metadata_by_type(job, artifact_type): + return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type] + + @classmethod + def _get_checkpoints(cls, job): + return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value) + + @classmethod + def _get_resources_allocated(cls, job): + data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) + return data[0] if data else None + + @webmethod(route="/post-training/job/status") + async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: + job = self._scheduler.get_job(job_uuid) + + match job.status: + # TODO: Add support for other statuses to API + case SchedulerJobStatus.new | SchedulerJobStatus.scheduled: + status = JobStatus.scheduled + case SchedulerJobStatus.running: + status = JobStatus.in_progress + case SchedulerJobStatus.completed: + status = JobStatus.completed + case SchedulerJobStatus.failed: + status = JobStatus.failed + case _: + raise NotImplementedError() + + return PostTrainingJobStatusResponse( + job_uuid=job_uuid, + status=status, + scheduled_at=job.scheduled_at, + started_at=job.started_at, + completed_at=job.completed_at, + checkpoints=self._get_checkpoints(job), + resources_allocated=self._get_resources_allocated(job), + ) + + @webmethod(route="/post-training/job/cancel") + async def cancel_training_job(self, job_uuid: str) -> None: + self._scheduler.cancel(job_uuid) + + @webmethod(route="/post-training/job/artifacts") + async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: + job = self._scheduler.get_job(job_uuid) + return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py new file mode 100644 index 000000000..b6d13b029 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -0,0 +1,683 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import gc +import json +import logging +import multiprocessing +import os +import signal +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import psutil + +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device + +# Set tokenizer parallelism environment variable +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Force PyTorch to use OpenBLAS instead of MKL +os.environ["MKL_THREADING_LAYER"] = "GNU" +os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" +os.environ["MKL_NUM_THREADS"] = "1" + +import torch +from datasets import Dataset +from peft import LoraConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, +) +from trl import SFTConfig, SFTTrainer + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + TrainingConfig, +) + +from ..config import HuggingFacePostTrainingConfig + +logger = logging.getLogger(__name__) + + +def get_gb(to_convert: int) -> str: + """Converts memory stats to GB and formats to 2 decimal places. + Args: + to_convert: Memory value in bytes + Returns: + str: Memory value in GB formatted to 2 decimal places + """ + return f"{(to_convert / (1024**3)):.2f}" + + +def get_memory_stats(device: torch.device) -> dict[str, Any]: + """Get memory statistics for the given device.""" + stats = { + "system_memory": { + "total": get_gb(psutil.virtual_memory().total), + "available": get_gb(psutil.virtual_memory().available), + "used": get_gb(psutil.virtual_memory().used), + "percent": psutil.virtual_memory().percent, + } + } + + if device.type == "cuda": + stats["device_memory"] = { + "allocated": get_gb(torch.cuda.memory_allocated(device)), + "reserved": get_gb(torch.cuda.memory_reserved(device)), + "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)), + } + elif device.type == "mps": + # MPS doesn't provide direct memory stats, but we can track system memory + stats["device_memory"] = { + "note": "MPS memory stats not directly available", + "system_memory_used": get_gb(psutil.virtual_memory().used), + } + elif device.type == "cpu": + # For CPU, we track process memory usage + process = psutil.Process() + stats["device_memory"] = { + "process_rss": get_gb(process.memory_info().rss), + "process_vms": get_gb(process.memory_info().vms), + "process_percent": process.memory_percent(), + } + + return stats + + +def setup_torch_device(device_str: str) -> torch.device: + """Initialize and validate a PyTorch device. + This function handles device initialization and validation for different device types: + - CUDA: Validates CUDA availability and handles device selection + - MPS: Validates MPS availability for Apple Silicon + - CPU: Basic validation + - HPU: Raises error as it's not supported + Args: + device_str: String specifying the device ('cuda', 'cpu', 'mps') + Returns: + torch.device: The initialized and validated device + Raises: + RuntimeError: If device initialization fails or device is not supported + """ + try: + device = torch.device(device_str) + except RuntimeError as e: + raise RuntimeError(f"Error getting Torch Device {str(e)}") from e + + # Validate device capabilities + if device.type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError( + f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device." 
+ ) + if device.index is None: + device = torch.device(device.type, torch.cuda.current_device()) + elif device.type == "mps": + if not torch.backends.mps.is_available(): + raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.") + elif device.type == "hpu": + raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.") + + return device + + +class HFFinetuningSingleDevice: + def __init__( + self, + job_uuid: str, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ): + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.job_uuid = job_uuid + + def validate_dataset_format(self, rows: list[dict]) -> bool: + """Validate that the dataset has the required fields.""" + required_fields = ["input_query", "expected_answer", "chat_completion_input"] + return all(field in row for row in rows for field in required_fields) + + def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in instruct format.""" + if "chat_completion_input" in row and "expected_answer" in row: + try: + messages = json.loads(row["chat_completion_input"]) + if not isinstance(messages, list) or len(messages) != 1: + logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") + return None, None + if "content" not in messages[0]: + logger.warning(f"Message missing content: {messages[0]}") + return None, None + return messages[0]["content"], row["expected_answer"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") + return None, None + return None, None + + def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in dialog format.""" + if "dialog" in row: + try: + dialog = json.loads(row["dialog"]) + if not isinstance(dialog, list) or len(dialog) < 2: + logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") + return None, None + if dialog[0].get("role") != "user": + logger.warning(f"First message must be from user: {dialog[0]}") + return None, None + if not any(msg.get("role") == "assistant" for msg in dialog): + logger.warning("Dialog must have at least one assistant message") + return None, None + + # Convert to human/gpt format + role_map = {"user": "human", "assistant": "gpt"} + conversations = [] + for msg in dialog: + if "role" not in msg or "content" not in msg: + logger.warning(f"Message missing role or content: {msg}") + continue + conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) + + # Format as a single conversation + return conversations[0]["value"], conversations[1]["value"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse dialog: {row['dialog']}") + return None, None + return None, None + + def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row using fallback formats.""" + if "input" in row and "output" in row: + return row["input"], row["output"] + elif "prompt" in row and "completion" in row: + return row["prompt"], row["completion"] + elif "question" in row and "answer" in row: + return row["question"], row["answer"] + return None, None + + def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str: + """Format input and output text based on model requirements.""" + if hasattr(provider_config, "chat_template"): + return provider_config.chat_template.format(input=input_text, 
output=output_text) + return f"{input_text}\n{output_text}" + + def _create_dataset( + self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Create and preprocess the dataset.""" + formatted_rows = [] + for row in rows: + input_text = None + output_text = None + + # Process based on format + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + if config.data_config.data_format.value == "instruct": + input_text, output_text = self._process_instruct_format(row) + elif config.data_config.data_format.value == "dialog": + input_text, output_text = self._process_dialog_format(row) + else: + input_text, output_text = self._process_fallback_format(row) + + if input_text and output_text: + formatted_text = self._format_text(input_text, output_text, provider_config) + formatted_rows.append({"text": formatted_text}) + + if not formatted_rows: + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + raise ValueError( + f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}" + ) + + return Dataset.from_list(formatted_rows) + + def _preprocess_dataset( + self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Preprocess the dataset with tokenizer.""" + + def tokenize_function(examples): + return tokenizer( + examples["text"], + padding=True, + truncation=True, + max_length=provider_config.max_seq_length, + return_tensors=None, + ) + + return ds.map( + tokenize_function, + batched=True, + remove_columns=ds.column_names, + ) + + async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: + """Load dataset from llama stack dataset provider""" + try: + all_rows = await self.datasetio_api.iterrows( + dataset_id=dataset_id, + limit=-1, + ) + if not isinstance(all_rows.data, list): + raise RuntimeError("Expected dataset data to be a list") + return all_rows.data + except Exception as e: + raise RuntimeError(f"Failed to load dataset: {str(e)}") from e + + def _run_training_sync( + self, + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Synchronous wrapper for running training process. + This method serves as a bridge between the multiprocessing Process and the async training function. + It creates a new event loop to run the async training process. + Args: + model: The model identifier to load + dataset_id: ID of the dataset to use for training + provider_config: Configuration specific to the HuggingFace provider + peft_config: Optional LoRA configuration + config: General training configuration + output_dir_path: Optional path to save the model + """ + import asyncio + + logger.info("Starting training process with async wrapper") + asyncio.run( + self._run_training( + model=model, + provider_config=provider_config, + peft_config=peft_config, + config=config, + output_dir_path=output_dir_path, + ) + ) + + async def load_dataset( + self, + model: str, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[Dataset, Dataset, AutoTokenizer]: + """Load and prepare the dataset for training. 
+ Args: + model: The model identifier to load + config: Training configuration + provider_config: Provider-specific configuration + Returns: + tuple: (train_dataset, eval_dataset, tokenizer) + """ + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Load dataset + logger.info(f"Loading dataset: {config.data_config.dataset_id}") + rows = await self._setup_data(config.data_config.dataset_id) + if not self.validate_dataset_format(rows): + raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") + logger.info(f"Loaded {len(rows)} rows from dataset") + + # Initialize tokenizer + logger.info(f"Initializing tokenizer for model: {model}") + try: + tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) + + # Set pad token to eos token if not present + # This is common for models that don't have a dedicated pad token + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + # Set padding side to right for causal language modeling + # This ensures that padding tokens don't interfere with the model's ability + # to predict the next token in the sequence + tokenizer.padding_side = "right" + + # Set truncation side to right to keep the beginning of the sequence + # This is important for maintaining context and instruction format + tokenizer.truncation_side = "right" + + # Set model max length to match provider config + # This ensures consistent sequence lengths across the training process + tokenizer.model_max_length = provider_config.max_seq_length + + logger.info("Tokenizer initialized successfully") + except Exception as e: + raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e + + # Create and preprocess dataset + logger.info("Creating and preprocessing dataset") + try: + ds = self._create_dataset(rows, config, provider_config) + ds = self._preprocess_dataset(ds, tokenizer, provider_config) + logger.info(f"Dataset created with {len(ds)} examples") + except Exception as e: + raise ValueError(f"Failed to create dataset: {str(e)}") from e + + # Split dataset + logger.info("Splitting dataset into train and validation sets") + train_val_split = ds.train_test_split(test_size=0.1, seed=42) + train_dataset = train_val_split["train"] + eval_dataset = train_val_split["test"] + logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + + return train_dataset, eval_dataset, tokenizer + + def load_model( + self, + model: str, + device: torch.device, + provider_config: HuggingFacePostTrainingConfig, + ) -> AutoModelForCausalLM: + """Load and initialize the model for training. 
+ Args: + model: The model identifier to load + device: The device to load the model onto + provider_config: Provider-specific configuration + Returns: + The loaded and initialized model + Raises: + RuntimeError: If model loading fails + """ + logger.info("Loading the base model") + try: + model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) + model_obj = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype="auto" if device.type != "cpu" else "float32", + quantization_config=None, + config=model_config, + **provider_config.model_specific_config, + ) + # Always move model to specified device + model_obj = model_obj.to(device) + logger.info(f"Model loaded and moved to device: {model_obj.device}") + return model_obj + except Exception as e: + raise RuntimeError(f"Failed to load model: {str(e)}") from e + + def setup_training_args( + self, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + device: torch.device, + output_dir_path: Path | None, + steps_per_epoch: int, + ) -> SFTConfig: + """Setup training arguments. + Args: + config: Training configuration + provider_config: Provider-specific configuration + device: The device to train on + output_dir_path: Optional path to save the model + steps_per_epoch: Number of steps per epoch + Returns: + Configured SFTConfig object + """ + logger.info("Configuring training arguments") + lr = 2e-5 + if config.optimizer_config: + lr = config.optimizer_config.lr + logger.info(f"Using custom learning rate: {lr}") + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + data_config = config.data_config + + # Calculate steps + total_steps = steps_per_epoch * config.n_epochs + max_steps = min(config.max_steps_per_epoch, total_steps) + eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch + save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch + logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + + logger.info("Training configuration:") + logger.info(f"- Steps per epoch: {steps_per_epoch}") + logger.info(f"- Total steps: {total_steps}") + logger.info(f"- Max steps: {max_steps}") + logger.info(f"- Eval steps: {eval_steps}") + logger.info(f"- Save steps: {save_steps}") + logger.info(f"- Logging steps: {logging_steps}") + + # Configure save strategy + save_strategy = "no" + if output_dir_path: + save_strategy = "steps" + logger.info(f"Will save checkpoints to {output_dir_path}") + + return SFTConfig( + max_steps=max_steps, + output_dir=str(output_dir_path) if output_dir_path is not None else None, + num_train_epochs=config.n_epochs, + per_device_train_batch_size=data_config.batch_size, + fp16=device.type == "cuda", + bf16=False, # Causes CPU issues. 
+ eval_strategy="steps", + use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False, + save_strategy=save_strategy, + report_to="none", + max_seq_length=provider_config.max_seq_length, + gradient_accumulation_steps=config.gradient_accumulation_steps, + gradient_checkpointing=provider_config.gradient_checkpointing, + learning_rate=lr, + warmup_ratio=provider_config.warmup_ratio, + weight_decay=provider_config.weight_decay, + remove_unused_columns=False, + dataloader_pin_memory=provider_config.dataloader_pin_memory, + dataloader_num_workers=provider_config.dataloader_num_workers, + dataset_text_field="text", + packing=False, + load_best_model_at_end=True if output_dir_path else False, + metric_for_best_model="eval_loss", + greater_is_better=False, + eval_steps=eval_steps, + save_steps=save_steps, + logging_steps=logging_steps, + ) + + def save_model( + self, + model_obj: AutoModelForCausalLM, + trainer: SFTTrainer, + peft_config: LoraConfig | None, + output_dir_path: Path, + ) -> None: + """Save the trained model. + Args: + model_obj: The model to save + trainer: The trainer instance + peft_config: Optional LoRA configuration + output_dir_path: Path to save the model + """ + logger.info("Saving final model") + model_obj.config.use_cache = True + + if peft_config: + logger.info("Merging LoRA weights with base model") + model_obj = trainer.model.merge_and_unload() + else: + model_obj = trainer.model + + save_path = output_dir_path / "merged_model" + logger.info(f"Saving model to {save_path}") + model_obj.save_pretrained(save_path) + + async def _run_training( + self, + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Run the training process with signal handling.""" + + def signal_handler(signum, frame): + """Handle termination signals gracefully.""" + logger.info(f"Received signal {signum}, initiating graceful shutdown") + sys.exit(0) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + # Convert config dicts back to objects + logger.info("Initializing configuration objects") + provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) + config_obj = TrainingConfig(**config) + + # Initialize and validate device + device = setup_torch_device(provider_config_obj.device) + logger.info(f"Using device '{device}'") + + # Load dataset and tokenizer + train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) + + # Calculate steps per epoch + if not config_obj.data_config: + raise ValueError("DataConfig is required for training") + steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size + + # Setup training arguments + training_args = self.setup_training_args( + config_obj, + provider_config_obj, + device, + output_dir_path, + steps_per_epoch, + ) + + # Load model + model_obj = self.load_model(model, device, provider_config_obj) + + # Initialize trainer + logger.info("Initializing SFTTrainer") + trainer = SFTTrainer( + model=model_obj, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + args=training_args, + ) + + try: + # Train + logger.info("Starting training") + trainer.train() + logger.info("Training completed successfully") + + # Save final model if output directory is provided + if output_dir_path: + self.save_model(model_obj, trainer, peft_config, output_dir_path) + + finally: + # Clean up 
resources + logger.info("Cleaning up resources") + if hasattr(trainer, "model"): + evacuate_model_from_device(trainer.model, device.type) + del trainer + gc.collect() + logger.info("Cleanup completed") + + async def train( + self, + model: str, + output_dir: str | None, + job_uuid: str, + lora_config: LoraFinetuningConfig, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[dict[str, Any], list[Checkpoint] | None]: + """Train a model using HuggingFace's SFTTrainer""" + # Initialize and validate device + device = setup_torch_device(provider_config.device) + logger.info(f"Using device '{device}'") + + output_dir_path = None + if output_dir: + output_dir_path = Path(output_dir) + + # Track memory stats + memory_stats = { + "initial": get_memory_stats(device), + "after_training": None, + "final": None, + } + + # Configure LoRA + peft_config = None + if lora_config: + peft_config = LoraConfig( + lora_alpha=lora_config.alpha, + lora_dropout=0.1, + r=lora_config.rank, + bias="none", + task_type="CAUSAL_LM", + target_modules=lora_config.lora_attn_modules, + ) + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Train in a separate process + logger.info("Starting training in separate process") + try: + # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility + if device.type in ["cuda", "mps"]: + multiprocessing.set_start_method("spawn", force=True) + + process = multiprocessing.Process( + target=self._run_training_sync, + kwargs={ + "model": model, + "provider_config": provider_config.model_dump(), + "peft_config": peft_config, + "config": config.model_dump(), + "output_dir_path": output_dir_path, + }, + ) + process.start() + + # Monitor the process + while process.is_alive(): + process.join(timeout=1) # Check every second + if not process.is_alive(): + break + + # Get the return code + if process.exitcode != 0: + raise RuntimeError(f"Training failed with exit code {process.exitcode}") + + memory_stats["after_training"] = get_memory_stats(device) + + checkpoints = None + if output_dir_path: + # Create checkpoint + checkpoint = Checkpoint( + identifier=f"{model}-sft-{config.n_epochs}", + created_at=datetime.now(timezone.utc), + epoch=config.n_epochs, + post_training_job_id=job_uuid, + path=str(output_dir_path / "merged_model"), + ) + checkpoints = [checkpoint] + + return memory_stats, checkpoints + finally: + memory_stats["final"] = get_memory_stats(device) + gc.collect() diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index b5a495935..f56dd2499 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import gc import logging import os import time @@ -47,6 +46,7 @@ from llama_stack.apis.post_training import ( from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( TorchtuneCheckpointer, @@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice: checkpoints.append(checkpoint) # clean up the memory after training finishes - if self._device.type != "cpu": - self._model.to("cpu") - torch.cuda.empty_cache() - del self._model - gc.collect() + evacuate_model_from_device(self._model, self._device.type) return (memory_stats, checkpoints) diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 35567c07d..d752b8819 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]: Api.datasets, ], ), + InlineProviderSpec( + api=Api.post_training, + provider_type="inline::huggingface", + pip_packages=["torch", "trl", "transformers", "peft", "datasets"], + module="llama_stack.providers.inline.post_training.huggingface", + config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), remote_provider_spec( api=Api.post_training, adapter=AdapterSpec( diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index d1a17e48e..fb4ab9fda 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -441,6 +441,7 @@ "opentelemetry-exporter-otlp-proto-http", "opentelemetry-sdk", "pandas", + "peft", "pillow", "psycopg2-binary", "pymongo", @@ -451,9 +452,11 @@ "scikit-learn", "scipy", "sentencepiece", + "torch", "tqdm", "transformers", "tree_sitter", + "trl", "uvicorn" ], "open-benchmark": [ diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml index b4b5e2203..55cd189c6 100644 --- a/llama_stack/templates/experimental-post-training/build.yaml +++ b/llama_stack/templates/experimental-post-training/build.yaml @@ -13,9 +13,10 @@ distribution_spec: - inline::basic - inline::braintrust post_training: - - inline::torchtune + - inline::huggingface datasetio: - inline::localfs + - remote::huggingface telemetry: - inline::meta-reference agents: diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index 2ebdfe1aa..393cba41d 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -49,16 +49,24 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference 
config: {} post_training: - - provider_id: torchtune-post-training - provider_type: inline::torchtune - config: { + - provider_id: huggingface + provider_type: inline::huggingface + config: checkpoint_format: huggingface - } + distributed_backend: null + device: cpu agents: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 88e61bf8a..7d5363575 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -23,6 +23,8 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + post_training: + - inline::huggingface tool_runtime: - remote::brave-search - remote::tavily-search diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index d72d299ec..0b4f05128 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "post_training": ["inline::huggingface"], "tool_runtime": [ "remote::brave-search", "remote::tavily-search", @@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::faiss", config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) - + posttraining_provider = Provider( + provider_id="huggingface", + provider_type="inline::huggingface", + config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="ollama", @@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "post_training": [posttraining_provider], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "post_training": [posttraining_provider], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 651d58117..74d0822ca 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -5,6 +5,7 @@ apis: - datasetio - eval - inference +- post_training - safety - scoring - telemetry @@ -80,6 +81,13 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search diff 
--git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 1372486fe..71229be97 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -5,6 +5,7 @@ apis: - datasetio - eval - inference +- post_training - safety - scoring - telemetry @@ -78,6 +79,13 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search diff --git a/pyproject.toml b/pyproject.toml index ba7c2300a..1fe64f350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest", + "pytest-timeout", "pytest-asyncio", "pytest-cov", "pytest-html", diff --git a/requirements.txt b/requirements.txt index 0857a9886..c2571025a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,206 +1,69 @@ # This file was autogenerated by uv via the following command: # uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt annotated-types==0.7.0 - # via pydantic anyio==4.8.0 - # via - # httpx - # llama-stack-client - # openai attrs==25.1.0 - # via - # jsonschema - # referencing blobfile==3.0.0 - # via llama-stack cachetools==5.5.2 - # via google-auth certifi==2025.1.31 - # via - # httpcore - # httpx - # kubernetes - # requests charset-normalizer==3.4.1 - # via requests click==8.1.8 - # via llama-stack-client colorama==0.4.6 ; sys_platform == 'win32' - # via - # click - # tqdm distro==1.9.0 - # via - # llama-stack-client - # openai durationpy==0.9 - # via kubernetes exceptiongroup==1.2.2 ; python_full_version < '3.11' - # via anyio filelock==3.17.0 - # via - # blobfile - # huggingface-hub fire==0.7.0 - # via llama-stack fsspec==2024.12.0 - # via huggingface-hub google-auth==2.38.0 - # via kubernetes h11==0.16.0 - # via - # httpcore - # llama-stack httpcore==1.0.9 - # via httpx httpx==0.28.1 - # via - # llama-stack - # llama-stack-client - # openai huggingface-hub==0.29.0 - # via llama-stack idna==3.10 - # via - # anyio - # httpx - # requests jinja2==3.1.6 - # via llama-stack jiter==0.8.2 - # via openai jsonschema==4.23.0 - # via llama-stack jsonschema-specifications==2024.10.1 - # via jsonschema kubernetes==32.0.1 - # via llama-stack llama-stack-client==0.2.7 - # via llama-stack lxml==5.3.1 - # via blobfile markdown-it-py==3.0.0 - # via rich markupsafe==3.0.2 - # via jinja2 mdurl==0.1.2 - # via markdown-it-py numpy==2.2.3 - # via pandas oauthlib==3.2.2 - # via - # kubernetes - # requests-oauthlib openai==1.71.0 - # via llama-stack packaging==24.2 - # via huggingface-hub pandas==2.2.3 - # via llama-stack-client pillow==11.1.0 - # via llama-stack prompt-toolkit==3.0.50 - # via - # llama-stack - # llama-stack-client pyaml==25.1.0 - # via llama-stack-client pyasn1==0.6.1 - # via - # pyasn1-modules - # rsa pyasn1-modules==0.4.2 - # via google-auth pycryptodomex==3.21.0 - # via blobfile pydantic==2.10.6 - # via - # llama-stack - # llama-stack-client - # openai pydantic-core==2.27.2 - # via pydantic pygments==2.19.1 - # via rich python-dateutil==2.9.0.post0 - # via - # kubernetes - # pandas python-dotenv==1.0.1 - # via llama-stack pytz==2025.1 - # via pandas pyyaml==6.0.2 - # via - # huggingface-hub - # kubernetes - # pyaml referencing==0.36.2 - # via - # jsonschema - # jsonschema-specifications 
regex==2024.11.6 - # via tiktoken requests==2.32.3 - # via - # huggingface-hub - # kubernetes - # llama-stack - # requests-oauthlib - # tiktoken requests-oauthlib==2.0.0 - # via kubernetes rich==13.9.4 - # via - # llama-stack - # llama-stack-client rpds-py==0.22.3 - # via - # jsonschema - # referencing rsa==4.9 - # via google-auth setuptools==75.8.0 - # via llama-stack six==1.17.0 - # via - # kubernetes - # python-dateutil sniffio==1.3.1 - # via - # anyio - # llama-stack-client - # openai termcolor==2.5.0 - # via - # fire - # llama-stack - # llama-stack-client tiktoken==0.9.0 - # via llama-stack tqdm==4.67.1 - # via - # huggingface-hub - # llama-stack-client - # openai typing-extensions==4.12.2 - # via - # anyio - # huggingface-hub - # llama-stack-client - # openai - # pydantic - # pydantic-core - # referencing - # rich tzdata==2025.1 - # via pandas urllib3==2.3.0 - # via - # blobfile - # kubernetes - # requests wcwidth==0.2.13 - # via prompt-toolkit websocket-client==1.8.0 - # via kubernetes diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index 648ace9d6..bb4639d17 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -4,20 +4,38 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging +import sys +import time +import uuid + import pytest -from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.post_training import ( - Checkpoint, DataConfig, LoraFinetuningConfig, - OptimizerConfig, - PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, TrainingConfig, ) +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + + +@pytest.fixture(autouse=True) +def capture_output(capsys): + """Fixture to capture and display output during test execution.""" + yield + captured = capsys.readouterr() + if captured.out: + print("\nCaptured stdout:", captured.out) + if captured.err: + print("\nCaptured stderr:", captured.err) + + +# Force flush stdout to see prints immediately +sys.stdout.reconfigure(line_buffering=True) + # How to run this test: # # pytest llama_stack/providers/tests/post_training/test_post_training.py @@ -25,10 +43,31 @@ from llama_stack.apis.post_training import ( # -v -s --tb=short --disable-warnings -@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API") class TestPostTraining: - @pytest.mark.asyncio - async def test_supervised_fine_tune(self, post_training_stack): + @pytest.mark.integration + @pytest.mark.parametrize( + "purpose, source", + [ + ( + "post-training/messages", + { + "type": "uri", + "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + }, + ), + ], + ) + @pytest.mark.timeout(360) # 6 minutes timeout + def test_supervised_fine_tune(self, llama_stack_client, purpose, source): + logger.info("Starting supervised fine-tuning test") + + # register dataset to train + dataset = llama_stack_client.datasets.register( + purpose=purpose, + source=source, + ) + logger.info(f"Registered dataset with ID: {dataset.identifier}") + algorithm_config = LoraFinetuningConfig( type="LoRA", lora_attn_modules=["q_proj", "v_proj", "output_proj"], @@ -39,62 +78,74 @@ class TestPostTraining: ) data_config = DataConfig( - dataset_id="alpaca", + 
dataset_id=dataset.identifier, batch_size=1, shuffle=False, + data_format="instruct", ) - optimizer_config = OptimizerConfig( - optimizer_type="adamw", - lr=3e-4, - lr_min=3e-5, - weight_decay=0.1, - num_warmup_steps=100, - ) - + # setup training config with minimal settings training_config = TrainingConfig( n_epochs=1, data_config=data_config, - optimizer_config=optimizer_config, max_steps_per_epoch=1, gradient_accumulation_steps=1, ) - post_training_impl = post_training_stack - response = await post_training_impl.supervised_fine_tune( - job_uuid="1234", - model="Llama3.2-3B-Instruct", + + job_uuid = f"test-job{uuid.uuid4()}" + logger.info(f"Starting training job with UUID: {job_uuid}") + + # train with HF trl SFTTrainer as the default + _ = llama_stack_client.post_training.supervised_fine_tune( + job_uuid=job_uuid, + model="ibm-granite/granite-3.3-2b-instruct", algorithm_config=algorithm_config, training_config=training_config, hyperparam_search_config={}, logger_config={}, - checkpoint_dir="null", + checkpoint_dir=None, ) - assert isinstance(response, PostTrainingJob) - assert response.job_uuid == "1234" - @pytest.mark.asyncio - async def test_get_training_jobs(self, post_training_stack): - post_training_impl = post_training_stack - jobs_list = await post_training_impl.get_training_jobs() - assert isinstance(jobs_list, list) - assert jobs_list[0].job_uuid == "1234" + while True: + status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) + if not status: + logger.error("Job not found") + break - @pytest.mark.asyncio - async def test_get_training_job_status(self, post_training_stack): - post_training_impl = post_training_stack - job_status = await post_training_impl.get_training_job_status("1234") - assert isinstance(job_status, PostTrainingJobStatusResponse) - assert job_status.job_uuid == "1234" - assert job_status.status == JobStatus.completed - assert isinstance(job_status.checkpoints[0], Checkpoint) + logger.info(f"Current status: {status}") + if status.status == "completed": + break - @pytest.mark.asyncio - async def test_get_training_job_artifacts(self, post_training_stack): - post_training_impl = post_training_stack - job_artifacts = await post_training_impl.get_training_job_artifacts("1234") - assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) - assert job_artifacts.job_uuid == "1234" - assert isinstance(job_artifacts.checkpoints[0], Checkpoint) - assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0" - assert job_artifacts.checkpoints[0].epoch == 0 - assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path + logger.info("Waiting for job to complete...") + time.sleep(10) # Increased sleep time to reduce polling frequency + + artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) + logger.info(f"Job artifacts: {artifacts}") + + # TODO: Fix these tests to properly represent the Jobs API in training + # @pytest.mark.asyncio + # async def test_get_training_jobs(self, post_training_stack): + # post_training_impl = post_training_stack + # jobs_list = await post_training_impl.get_training_jobs() + # assert isinstance(jobs_list, list) + # assert jobs_list[0].job_uuid == "1234" + + # @pytest.mark.asyncio + # async def test_get_training_job_status(self, post_training_stack): + # post_training_impl = post_training_stack + # job_status = await post_training_impl.get_training_job_status("1234") + # assert isinstance(job_status, PostTrainingJobStatusResponse) + # assert 
job_status.job_uuid == "1234" + # assert job_status.status == JobStatus.completed + # assert isinstance(job_status.checkpoints[0], Checkpoint) + + # @pytest.mark.asyncio + # async def test_get_training_job_artifacts(self, post_training_stack): + # post_training_impl = post_training_stack + # job_artifacts = await post_training_impl.get_training_job_artifacts("1234") + # assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) + # assert job_artifacts.job_uuid == "1234" + # assert isinstance(job_artifacts.checkpoints[0], Checkpoint) + # assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab" + # assert job_artifacts.checkpoints[0].epoch == 0 + # assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path diff --git a/uv.lock b/uv.lock index dbf0c891f..6bd3f84d5 100644 --- a/uv.lock +++ b/uv.lock @@ -1459,6 +1459,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-html" }, { name = "pytest-json-report" }, + { name = "pytest-timeout" }, { name = "ruamel-yaml" }, { name = "ruff" }, { name = "types-requests" }, @@ -1557,6 +1558,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-html", marker = "extra == 'dev'" }, { name = "pytest-json-report", marker = "extra == 'dev'" }, + { name = "pytest-timeout", marker = "extra == 'dev'" }, { name = "python-dotenv" }, { name = "qdrant-client", marker = "extra == 'unit'" }, { name = "requests" }, @@ -2852,6 +2854,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"
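
For reference, below is a minimal end-to-end sketch of driving the new inline::huggingface post-training provider from a client, modeled directly on the integration test added in this diff. It is a sketch, not part of the patch: the base URL, job UUID prefix, polling interval, and the extra LoRA fields (apply_lora_to_mlp, apply_lora_to_output) are assumptions; the model, dataset source, and API calls mirror tests/integration/post_training/test_post_training.py.

# Sketch: register a dataset, launch a LoRA supervised fine-tune against the
# inline::huggingface provider, poll until the job finishes, then fetch artifacts.
# Assumes a Llama Stack server built from the ollama (or experimental-post-training)
# template is reachable at the URL below (assumed default local port).
import time
import uuid

from llama_stack_client import LlamaStackClient

from llama_stack.apis.post_training import (
    DataConfig,
    LoraFinetuningConfig,
    TrainingConfig,
)

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server URL

# Register the training dataset (same source as the integration test).
dataset = client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
)

# LoRA settings: attn modules match the test; rank/alpha map to the fields the
# recipe reads (lora_config.rank, lora_config.alpha); the apply_lora_to_* flags
# are assumed required by LoraFinetuningConfig.
algorithm_config = LoraFinetuningConfig(
    type="LoRA",
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)

# Minimal training config, matching the test's single-step run.
training_config = TrainingConfig(
    n_epochs=1,
    data_config=DataConfig(
        dataset_id=dataset.identifier,
        batch_size=1,
        shuffle=False,
        data_format="instruct",
    ),
    max_steps_per_epoch=1,
    gradient_accumulation_steps=1,
)

job_uuid = f"sft-{uuid.uuid4()}"  # illustrative job id
client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    model="ibm-granite/granite-3.3-2b-instruct",
    algorithm_config=algorithm_config,
    training_config=training_config,
    hyperparam_search_config={},
    logger_config={},
    checkpoint_dir=None,
)

# Poll job status (scheduled -> in_progress -> completed/failed), then fetch
# checkpoint artifacts once the job is done.
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if status and status.status in ("completed", "failed"):
        break
    time.sleep(10)

print(client.post_training.job.artifacts(job_uuid=job_uuid))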