From f02f7b28c17439e7899999ae6271d7a4be12f20f Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Fri, 16 May 2025 17:41:28 -0400
Subject: [PATCH] feat: add huggingface post_training impl (#2132)

# What does this PR do?

Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The config allows a user to specify key fields such as the model, chat_template, device, etc.

The provider comes with one recipe, `finetune_single_device`, which works both with and without LoRA. Any valid HF model identifier can be given, and the model will be pulled. This has been tested so far with the CPU and MPS device types, but it should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs n_epochs of training. If checkpoint_dir is None, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training.

## Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API, and runs one step with a batch_size of 1. This test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step.

[//]: # (## Documentation)

---------

Signed-off-by: Charlie Doern
---
 .github/workflows/integration-tests.yml      |  16 +-
 .../self_hosted_distro/ollama.md             |   1 +
 .../inline/post_training/common/utils.py     |  35 +
 .../post_training/huggingface/__init__.py    |  27 +
 .../post_training/huggingface/config.py      |  72 ++
 .../huggingface/post_training.py             | 176 +++++
 .../recipes/finetune_single_device.py        | 683 ++++++++++++++++++
 .../recipes/lora_finetuning_single_device.py |   8 +-
 .../providers/registry/post_training.py      |  11 +
 llama_stack/templates/dependencies.json      |   3 +
 .../experimental-post-training/build.yaml    |   3 +-
 .../experimental-post-training/run.yaml      |  16 +-
 llama_stack/templates/ollama/build.yaml      |   2 +
 llama_stack/templates/ollama/ollama.py       |  10 +-
 .../templates/ollama/run-with-safety.yaml    |   8 +
 llama_stack/templates/ollama/run.yaml        |   8 +
 pyproject.toml                               |   1 +
 requirements.txt                             | 137 ----
 .../post_training/test_post_training.py      | 151 ++--
 uv.lock                                      |  14 +
 20 files changed, 1181 insertions(+), 201 deletions(-)
 create mode 100644 llama_stack/providers/inline/post_training/common/utils.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface/__init__.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface/config.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface/post_training.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index d755ff0ae..c083da7d9 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -58,7 +58,7 @@ jobs:
         INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
       run: |
         source .venv/bin/activate
-        nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+        LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
      - name: Wait for
Llama Stack server to be ready if: matrix.client-type == 'http' @@ -85,6 +85,11 @@ jobs: echo "Ollama health check failed" exit 1 fi + - name: Check Storage and Memory Available Before Tests + if: ${{ always() }} + run: | + free -h + df -h - name: Run Integration Tests env: @@ -100,12 +105,19 @@ jobs: --text-model="meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 + - name: Check Storage and Memory Available After Tests + if: ${{ always() }} + run: | + free -h + df -h + - name: Write ollama logs to file + if: ${{ always() }} run: | sudo journalctl -u ollama.service > ollama.log - name: Upload all logs to artifacts - if: always() + if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }} diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index a3d67d4ce..4d148feda 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | | inference | `remote::ollama` | +| post_training | `inline::huggingface` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/llama_stack/providers/inline/post_training/common/utils.py b/llama_stack/providers/inline/post_training/common/utils.py new file mode 100644 index 000000000..7840b21e8 --- /dev/null +++ b/llama_stack/providers/inline/post_training/common/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import gc + + +def evacuate_model_from_device(model, device: str): + """Safely clear a model from memory and free device resources. + This function handles the proper cleanup of a model by: + 1. Moving the model to CPU if it's on a non-CPU device + 2. Deleting the model object to free memory + 3. Running garbage collection + 4. Clearing CUDA cache if the model was on a CUDA device + Args: + model: The PyTorch model to clear + device: The device type the model is currently on ('cuda', 'mps', 'cpu') + Note: + - For CUDA devices, this will clear the CUDA cache after moving the model to CPU + - For MPS devices, only moves the model to CPU (no cache clearing available) + - For CPU devices, only deletes the model object and runs garbage collection + """ + if device != "cpu": + model.to("cpu") + + del model + gc.collect() + + if device == "cuda": + # we need to import such that this is only imported when the method is called + import torch + + torch.cuda.empty_cache() diff --git a/llama_stack/providers/inline/post_training/huggingface/__init__.py b/llama_stack/providers/inline/post_training/huggingface/__init__.py new file mode 100644 index 000000000..cc1a671c1 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.distribution.datatypes import Api + +from .config import HuggingFacePostTrainingConfig + +# post_training api and the huggingface provider is still experimental and under heavy development + + +async def get_provider_impl( + config: HuggingFacePostTrainingConfig, + deps: dict[Api, Any], +): + from .post_training import HuggingFacePostTrainingImpl + + impl = HuggingFacePostTrainingImpl( + config, + deps[Api.datasetio], + deps[Api.datasets], + ) + return impl diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py new file mode 100644 index 000000000..06c6d8073 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Literal + +from pydantic import BaseModel + + +class HuggingFacePostTrainingConfig(BaseModel): + # Device to run training on (cuda, cpu, mps) + device: str = "cuda" + + # Distributed training backend if using multiple devices + # fsdp: Fully Sharded Data Parallel + # deepspeed: DeepSpeed ZeRO optimization + distributed_backend: Literal["fsdp", "deepspeed"] | None = None + + # Format for saving model checkpoints + # full_state: Save complete model state + # huggingface: Save in HuggingFace format (recommended for compatibility) + checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface" + + # Template for formatting chat inputs and outputs + # Used to structure the conversation format for training + chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}" + + # Model-specific configuration parameters + # trust_remote_code: Allow execution of custom model code + # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance + model_specific_config: dict = { + "trust_remote_code": True, + "attn_implementation": "sdpa", + } + + # Maximum sequence length for training + # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon) + # Longer sequences may cause memory issues on MPS devices + max_seq_length: int = 2048 + + # Enable gradient checkpointing to reduce memory usage + # Trades computation for memory by recomputing activations + gradient_checkpointing: bool = False + + # Maximum number of checkpoints to keep + # Older checkpoints are deleted when this limit is reached + save_total_limit: int = 3 + + # Number of training steps between logging updates + logging_steps: int = 10 + + # Ratio of training steps used for learning rate warmup + # Helps stabilize early training + warmup_ratio: float = 0.1 + + # L2 regularization coefficient + # Helps prevent overfitting + weight_decay: float = 0.01 + + # Number of worker processes for data loading + # Higher values can improve data loading speed but increase memory usage + dataloader_num_workers: int = 4 + + # Whether to pin memory in data loader + # Can improve data transfer speed to GPU but uses more memory + dataloader_pin_memory: bool = True + + @classmethod + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: + return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"} diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py 
b/llama_stack/providers/inline/post_training/huggingface/post_training.py new file mode 100644 index 000000000..0b2760792 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from enum import Enum +from typing import Any + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + Checkpoint, + DPOAlignmentConfig, + JobStatus, + ListPostTrainingJobsResponse, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) +from llama_stack.providers.inline.post_training.huggingface.config import ( + HuggingFacePostTrainingConfig, +) +from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( + HFFinetuningSingleDevice, +) +from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler +from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus +from llama_stack.schema_utils import webmethod + + +class TrainingArtifactType(Enum): + CHECKPOINT = "checkpoint" + RESOURCES_STATS = "resources_stats" + + +_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune" + + +class HuggingFacePostTrainingImpl: + def __init__( + self, + config: HuggingFacePostTrainingConfig, + datasetio_api: DatasetIO, + datasets: Datasets, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets + self._scheduler = Scheduler() + + async def shutdown(self) -> None: + await self._scheduler.shutdown() + + @staticmethod + def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.CHECKPOINT.value, + name=checkpoint.identifier, + uri=checkpoint.path, + metadata=dict(checkpoint), + ) + + @staticmethod + def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.RESOURCES_STATS.value, + name=TrainingArtifactType.RESOURCES_STATS.value, + metadata=resources_stats, + ) + + async def supervised_fine_tune( + self, + job_uuid: str, + training_config: TrainingConfig, + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], + model: str, + checkpoint_dir: str | None = None, + algorithm_config: AlgorithmConfig | None = None, + ) -> PostTrainingJob: + async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + on_log_message_cb("Starting HF finetuning") + + recipe = HFFinetuningSingleDevice( + job_uuid=job_uuid, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + ) + + resources_allocated, checkpoints = await recipe.train( + model=model, + output_dir=checkpoint_dir, + job_uuid=job_uuid, + lora_config=algorithm_config, + config=training_config, + provider_config=self.config, + ) + + on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = self._checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + + on_status_change_cb(SchedulerJobStatus.completed) + on_log_message_cb("HF finetuning completed") + + job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) + return 
PostTrainingJob(job_uuid=job_uuid) + + async def preference_optimize( + self, + job_uuid: str, + finetuned_model: str, + algorithm_config: DPOAlignmentConfig, + training_config: TrainingConfig, + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], + ) -> PostTrainingJob: + raise NotImplementedError("DPO alignment is not implemented yet") + + async def get_training_jobs(self) -> ListPostTrainingJobsResponse: + return ListPostTrainingJobsResponse( + data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] + ) + + @staticmethod + def _get_artifacts_metadata_by_type(job, artifact_type): + return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type] + + @classmethod + def _get_checkpoints(cls, job): + return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value) + + @classmethod + def _get_resources_allocated(cls, job): + data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) + return data[0] if data else None + + @webmethod(route="/post-training/job/status") + async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: + job = self._scheduler.get_job(job_uuid) + + match job.status: + # TODO: Add support for other statuses to API + case SchedulerJobStatus.new | SchedulerJobStatus.scheduled: + status = JobStatus.scheduled + case SchedulerJobStatus.running: + status = JobStatus.in_progress + case SchedulerJobStatus.completed: + status = JobStatus.completed + case SchedulerJobStatus.failed: + status = JobStatus.failed + case _: + raise NotImplementedError() + + return PostTrainingJobStatusResponse( + job_uuid=job_uuid, + status=status, + scheduled_at=job.scheduled_at, + started_at=job.started_at, + completed_at=job.completed_at, + checkpoints=self._get_checkpoints(job), + resources_allocated=self._get_resources_allocated(job), + ) + + @webmethod(route="/post-training/job/cancel") + async def cancel_training_job(self, job_uuid: str) -> None: + self._scheduler.cancel(job_uuid) + + @webmethod(route="/post-training/job/artifacts") + async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: + job = self._scheduler.get_job(job_uuid) + return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py new file mode 100644 index 000000000..b6d13b029 --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -0,0 +1,683 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import gc +import json +import logging +import multiprocessing +import os +import signal +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import psutil + +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device + +# Set tokenizer parallelism environment variable +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Force PyTorch to use OpenBLAS instead of MKL +os.environ["MKL_THREADING_LAYER"] = "GNU" +os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" +os.environ["MKL_NUM_THREADS"] = "1" + +import torch +from datasets import Dataset +from peft import LoraConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, +) +from trl import SFTConfig, SFTTrainer + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + TrainingConfig, +) + +from ..config import HuggingFacePostTrainingConfig + +logger = logging.getLogger(__name__) + + +def get_gb(to_convert: int) -> str: + """Converts memory stats to GB and formats to 2 decimal places. + Args: + to_convert: Memory value in bytes + Returns: + str: Memory value in GB formatted to 2 decimal places + """ + return f"{(to_convert / (1024**3)):.2f}" + + +def get_memory_stats(device: torch.device) -> dict[str, Any]: + """Get memory statistics for the given device.""" + stats = { + "system_memory": { + "total": get_gb(psutil.virtual_memory().total), + "available": get_gb(psutil.virtual_memory().available), + "used": get_gb(psutil.virtual_memory().used), + "percent": psutil.virtual_memory().percent, + } + } + + if device.type == "cuda": + stats["device_memory"] = { + "allocated": get_gb(torch.cuda.memory_allocated(device)), + "reserved": get_gb(torch.cuda.memory_reserved(device)), + "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)), + } + elif device.type == "mps": + # MPS doesn't provide direct memory stats, but we can track system memory + stats["device_memory"] = { + "note": "MPS memory stats not directly available", + "system_memory_used": get_gb(psutil.virtual_memory().used), + } + elif device.type == "cpu": + # For CPU, we track process memory usage + process = psutil.Process() + stats["device_memory"] = { + "process_rss": get_gb(process.memory_info().rss), + "process_vms": get_gb(process.memory_info().vms), + "process_percent": process.memory_percent(), + } + + return stats + + +def setup_torch_device(device_str: str) -> torch.device: + """Initialize and validate a PyTorch device. + This function handles device initialization and validation for different device types: + - CUDA: Validates CUDA availability and handles device selection + - MPS: Validates MPS availability for Apple Silicon + - CPU: Basic validation + - HPU: Raises error as it's not supported + Args: + device_str: String specifying the device ('cuda', 'cpu', 'mps') + Returns: + torch.device: The initialized and validated device + Raises: + RuntimeError: If device initialization fails or device is not supported + """ + try: + device = torch.device(device_str) + except RuntimeError as e: + raise RuntimeError(f"Error getting Torch Device {str(e)}") from e + + # Validate device capabilities + if device.type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError( + f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device." 
+ ) + if device.index is None: + device = torch.device(device.type, torch.cuda.current_device()) + elif device.type == "mps": + if not torch.backends.mps.is_available(): + raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.") + elif device.type == "hpu": + raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.") + + return device + + +class HFFinetuningSingleDevice: + def __init__( + self, + job_uuid: str, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ): + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.job_uuid = job_uuid + + def validate_dataset_format(self, rows: list[dict]) -> bool: + """Validate that the dataset has the required fields.""" + required_fields = ["input_query", "expected_answer", "chat_completion_input"] + return all(field in row for row in rows for field in required_fields) + + def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in instruct format.""" + if "chat_completion_input" in row and "expected_answer" in row: + try: + messages = json.loads(row["chat_completion_input"]) + if not isinstance(messages, list) or len(messages) != 1: + logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") + return None, None + if "content" not in messages[0]: + logger.warning(f"Message missing content: {messages[0]}") + return None, None + return messages[0]["content"], row["expected_answer"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") + return None, None + return None, None + + def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row in dialog format.""" + if "dialog" in row: + try: + dialog = json.loads(row["dialog"]) + if not isinstance(dialog, list) or len(dialog) < 2: + logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") + return None, None + if dialog[0].get("role") != "user": + logger.warning(f"First message must be from user: {dialog[0]}") + return None, None + if not any(msg.get("role") == "assistant" for msg in dialog): + logger.warning("Dialog must have at least one assistant message") + return None, None + + # Convert to human/gpt format + role_map = {"user": "human", "assistant": "gpt"} + conversations = [] + for msg in dialog: + if "role" not in msg or "content" not in msg: + logger.warning(f"Message missing role or content: {msg}") + continue + conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) + + # Format as a single conversation + return conversations[0]["value"], conversations[1]["value"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse dialog: {row['dialog']}") + return None, None + return None, None + + def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]: + """Process a row using fallback formats.""" + if "input" in row and "output" in row: + return row["input"], row["output"] + elif "prompt" in row and "completion" in row: + return row["prompt"], row["completion"] + elif "question" in row and "answer" in row: + return row["question"], row["answer"] + return None, None + + def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str: + """Format input and output text based on model requirements.""" + if hasattr(provider_config, "chat_template"): + return provider_config.chat_template.format(input=input_text, 
output=output_text) + return f"{input_text}\n{output_text}" + + def _create_dataset( + self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Create and preprocess the dataset.""" + formatted_rows = [] + for row in rows: + input_text = None + output_text = None + + # Process based on format + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + if config.data_config.data_format.value == "instruct": + input_text, output_text = self._process_instruct_format(row) + elif config.data_config.data_format.value == "dialog": + input_text, output_text = self._process_dialog_format(row) + else: + input_text, output_text = self._process_fallback_format(row) + + if input_text and output_text: + formatted_text = self._format_text(input_text, output_text, provider_config) + formatted_rows.append({"text": formatted_text}) + + if not formatted_rows: + assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized" + raise ValueError( + f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}" + ) + + return Dataset.from_list(formatted_rows) + + def _preprocess_dataset( + self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Preprocess the dataset with tokenizer.""" + + def tokenize_function(examples): + return tokenizer( + examples["text"], + padding=True, + truncation=True, + max_length=provider_config.max_seq_length, + return_tensors=None, + ) + + return ds.map( + tokenize_function, + batched=True, + remove_columns=ds.column_names, + ) + + async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: + """Load dataset from llama stack dataset provider""" + try: + all_rows = await self.datasetio_api.iterrows( + dataset_id=dataset_id, + limit=-1, + ) + if not isinstance(all_rows.data, list): + raise RuntimeError("Expected dataset data to be a list") + return all_rows.data + except Exception as e: + raise RuntimeError(f"Failed to load dataset: {str(e)}") from e + + def _run_training_sync( + self, + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Synchronous wrapper for running training process. + This method serves as a bridge between the multiprocessing Process and the async training function. + It creates a new event loop to run the async training process. + Args: + model: The model identifier to load + dataset_id: ID of the dataset to use for training + provider_config: Configuration specific to the HuggingFace provider + peft_config: Optional LoRA configuration + config: General training configuration + output_dir_path: Optional path to save the model + """ + import asyncio + + logger.info("Starting training process with async wrapper") + asyncio.run( + self._run_training( + model=model, + provider_config=provider_config, + peft_config=peft_config, + config=config, + output_dir_path=output_dir_path, + ) + ) + + async def load_dataset( + self, + model: str, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[Dataset, Dataset, AutoTokenizer]: + """Load and prepare the dataset for training. 
+ Args: + model: The model identifier to load + config: Training configuration + provider_config: Provider-specific configuration + Returns: + tuple: (train_dataset, eval_dataset, tokenizer) + """ + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Load dataset + logger.info(f"Loading dataset: {config.data_config.dataset_id}") + rows = await self._setup_data(config.data_config.dataset_id) + if not self.validate_dataset_format(rows): + raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") + logger.info(f"Loaded {len(rows)} rows from dataset") + + # Initialize tokenizer + logger.info(f"Initializing tokenizer for model: {model}") + try: + tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) + + # Set pad token to eos token if not present + # This is common for models that don't have a dedicated pad token + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + # Set padding side to right for causal language modeling + # This ensures that padding tokens don't interfere with the model's ability + # to predict the next token in the sequence + tokenizer.padding_side = "right" + + # Set truncation side to right to keep the beginning of the sequence + # This is important for maintaining context and instruction format + tokenizer.truncation_side = "right" + + # Set model max length to match provider config + # This ensures consistent sequence lengths across the training process + tokenizer.model_max_length = provider_config.max_seq_length + + logger.info("Tokenizer initialized successfully") + except Exception as e: + raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e + + # Create and preprocess dataset + logger.info("Creating and preprocessing dataset") + try: + ds = self._create_dataset(rows, config, provider_config) + ds = self._preprocess_dataset(ds, tokenizer, provider_config) + logger.info(f"Dataset created with {len(ds)} examples") + except Exception as e: + raise ValueError(f"Failed to create dataset: {str(e)}") from e + + # Split dataset + logger.info("Splitting dataset into train and validation sets") + train_val_split = ds.train_test_split(test_size=0.1, seed=42) + train_dataset = train_val_split["train"] + eval_dataset = train_val_split["test"] + logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + + return train_dataset, eval_dataset, tokenizer + + def load_model( + self, + model: str, + device: torch.device, + provider_config: HuggingFacePostTrainingConfig, + ) -> AutoModelForCausalLM: + """Load and initialize the model for training. 
+ Args: + model: The model identifier to load + device: The device to load the model onto + provider_config: Provider-specific configuration + Returns: + The loaded and initialized model + Raises: + RuntimeError: If model loading fails + """ + logger.info("Loading the base model") + try: + model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) + model_obj = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype="auto" if device.type != "cpu" else "float32", + quantization_config=None, + config=model_config, + **provider_config.model_specific_config, + ) + # Always move model to specified device + model_obj = model_obj.to(device) + logger.info(f"Model loaded and moved to device: {model_obj.device}") + return model_obj + except Exception as e: + raise RuntimeError(f"Failed to load model: {str(e)}") from e + + def setup_training_args( + self, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + device: torch.device, + output_dir_path: Path | None, + steps_per_epoch: int, + ) -> SFTConfig: + """Setup training arguments. + Args: + config: Training configuration + provider_config: Provider-specific configuration + device: The device to train on + output_dir_path: Optional path to save the model + steps_per_epoch: Number of steps per epoch + Returns: + Configured SFTConfig object + """ + logger.info("Configuring training arguments") + lr = 2e-5 + if config.optimizer_config: + lr = config.optimizer_config.lr + logger.info(f"Using custom learning rate: {lr}") + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + data_config = config.data_config + + # Calculate steps + total_steps = steps_per_epoch * config.n_epochs + max_steps = min(config.max_steps_per_epoch, total_steps) + eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch + save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch + logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + + logger.info("Training configuration:") + logger.info(f"- Steps per epoch: {steps_per_epoch}") + logger.info(f"- Total steps: {total_steps}") + logger.info(f"- Max steps: {max_steps}") + logger.info(f"- Eval steps: {eval_steps}") + logger.info(f"- Save steps: {save_steps}") + logger.info(f"- Logging steps: {logging_steps}") + + # Configure save strategy + save_strategy = "no" + if output_dir_path: + save_strategy = "steps" + logger.info(f"Will save checkpoints to {output_dir_path}") + + return SFTConfig( + max_steps=max_steps, + output_dir=str(output_dir_path) if output_dir_path is not None else None, + num_train_epochs=config.n_epochs, + per_device_train_batch_size=data_config.batch_size, + fp16=device.type == "cuda", + bf16=False, # Causes CPU issues. 
+ eval_strategy="steps", + use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False, + save_strategy=save_strategy, + report_to="none", + max_seq_length=provider_config.max_seq_length, + gradient_accumulation_steps=config.gradient_accumulation_steps, + gradient_checkpointing=provider_config.gradient_checkpointing, + learning_rate=lr, + warmup_ratio=provider_config.warmup_ratio, + weight_decay=provider_config.weight_decay, + remove_unused_columns=False, + dataloader_pin_memory=provider_config.dataloader_pin_memory, + dataloader_num_workers=provider_config.dataloader_num_workers, + dataset_text_field="text", + packing=False, + load_best_model_at_end=True if output_dir_path else False, + metric_for_best_model="eval_loss", + greater_is_better=False, + eval_steps=eval_steps, + save_steps=save_steps, + logging_steps=logging_steps, + ) + + def save_model( + self, + model_obj: AutoModelForCausalLM, + trainer: SFTTrainer, + peft_config: LoraConfig | None, + output_dir_path: Path, + ) -> None: + """Save the trained model. + Args: + model_obj: The model to save + trainer: The trainer instance + peft_config: Optional LoRA configuration + output_dir_path: Path to save the model + """ + logger.info("Saving final model") + model_obj.config.use_cache = True + + if peft_config: + logger.info("Merging LoRA weights with base model") + model_obj = trainer.model.merge_and_unload() + else: + model_obj = trainer.model + + save_path = output_dir_path / "merged_model" + logger.info(f"Saving model to {save_path}") + model_obj.save_pretrained(save_path) + + async def _run_training( + self, + model: str, + provider_config: dict[str, Any], + peft_config: LoraConfig | None, + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Run the training process with signal handling.""" + + def signal_handler(signum, frame): + """Handle termination signals gracefully.""" + logger.info(f"Received signal {signum}, initiating graceful shutdown") + sys.exit(0) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + # Convert config dicts back to objects + logger.info("Initializing configuration objects") + provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) + config_obj = TrainingConfig(**config) + + # Initialize and validate device + device = setup_torch_device(provider_config_obj.device) + logger.info(f"Using device '{device}'") + + # Load dataset and tokenizer + train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) + + # Calculate steps per epoch + if not config_obj.data_config: + raise ValueError("DataConfig is required for training") + steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size + + # Setup training arguments + training_args = self.setup_training_args( + config_obj, + provider_config_obj, + device, + output_dir_path, + steps_per_epoch, + ) + + # Load model + model_obj = self.load_model(model, device, provider_config_obj) + + # Initialize trainer + logger.info("Initializing SFTTrainer") + trainer = SFTTrainer( + model=model_obj, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + args=training_args, + ) + + try: + # Train + logger.info("Starting training") + trainer.train() + logger.info("Training completed successfully") + + # Save final model if output directory is provided + if output_dir_path: + self.save_model(model_obj, trainer, peft_config, output_dir_path) + + finally: + # Clean up 
resources + logger.info("Cleaning up resources") + if hasattr(trainer, "model"): + evacuate_model_from_device(trainer.model, device.type) + del trainer + gc.collect() + logger.info("Cleanup completed") + + async def train( + self, + model: str, + output_dir: str | None, + job_uuid: str, + lora_config: LoraFinetuningConfig, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[dict[str, Any], list[Checkpoint] | None]: + """Train a model using HuggingFace's SFTTrainer""" + # Initialize and validate device + device = setup_torch_device(provider_config.device) + logger.info(f"Using device '{device}'") + + output_dir_path = None + if output_dir: + output_dir_path = Path(output_dir) + + # Track memory stats + memory_stats = { + "initial": get_memory_stats(device), + "after_training": None, + "final": None, + } + + # Configure LoRA + peft_config = None + if lora_config: + peft_config = LoraConfig( + lora_alpha=lora_config.alpha, + lora_dropout=0.1, + r=lora_config.rank, + bias="none", + task_type="CAUSAL_LM", + target_modules=lora_config.lora_attn_modules, + ) + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Train in a separate process + logger.info("Starting training in separate process") + try: + # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility + if device.type in ["cuda", "mps"]: + multiprocessing.set_start_method("spawn", force=True) + + process = multiprocessing.Process( + target=self._run_training_sync, + kwargs={ + "model": model, + "provider_config": provider_config.model_dump(), + "peft_config": peft_config, + "config": config.model_dump(), + "output_dir_path": output_dir_path, + }, + ) + process.start() + + # Monitor the process + while process.is_alive(): + process.join(timeout=1) # Check every second + if not process.is_alive(): + break + + # Get the return code + if process.exitcode != 0: + raise RuntimeError(f"Training failed with exit code {process.exitcode}") + + memory_stats["after_training"] = get_memory_stats(device) + + checkpoints = None + if output_dir_path: + # Create checkpoint + checkpoint = Checkpoint( + identifier=f"{model}-sft-{config.n_epochs}", + created_at=datetime.now(timezone.utc), + epoch=config.n_epochs, + post_training_job_id=job_uuid, + path=str(output_dir_path / "merged_model"), + ) + checkpoints = [checkpoint] + + return memory_stats, checkpoints + finally: + memory_stats["final"] = get_memory_stats(device) + gc.collect() diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index b5a495935..f56dd2499 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import gc import logging import os import time @@ -47,6 +46,7 @@ from llama_stack.apis.post_training import ( from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( TorchtuneCheckpointer, @@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice: checkpoints.append(checkpoint) # clean up the memory after training finishes - if self._device.type != "cpu": - self._model.to("cpu") - torch.cuda.empty_cache() - del self._model - gc.collect() + evacuate_model_from_device(self._model, self._device.type) return (memory_stats, checkpoints) diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 35567c07d..d752b8819 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]: Api.datasets, ], ), + InlineProviderSpec( + api=Api.post_training, + provider_type="inline::huggingface", + pip_packages=["torch", "trl", "transformers", "peft", "datasets"], + module="llama_stack.providers.inline.post_training.huggingface", + config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), remote_provider_spec( api=Api.post_training, adapter=AdapterSpec( diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index d1a17e48e..fb4ab9fda 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -441,6 +441,7 @@ "opentelemetry-exporter-otlp-proto-http", "opentelemetry-sdk", "pandas", + "peft", "pillow", "psycopg2-binary", "pymongo", @@ -451,9 +452,11 @@ "scikit-learn", "scipy", "sentencepiece", + "torch", "tqdm", "transformers", "tree_sitter", + "trl", "uvicorn" ], "open-benchmark": [ diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml index b4b5e2203..55cd189c6 100644 --- a/llama_stack/templates/experimental-post-training/build.yaml +++ b/llama_stack/templates/experimental-post-training/build.yaml @@ -13,9 +13,10 @@ distribution_spec: - inline::basic - inline::braintrust post_training: - - inline::torchtune + - inline::huggingface datasetio: - inline::localfs + - remote::huggingface telemetry: - inline::meta-reference agents: diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index 2ebdfe1aa..393cba41d 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -49,16 +49,24 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference 
config: {} post_training: - - provider_id: torchtune-post-training - provider_type: inline::torchtune - config: { + - provider_id: huggingface + provider_type: inline::huggingface + config: checkpoint_format: huggingface - } + distributed_backend: null + device: cpu agents: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 88e61bf8a..7d5363575 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -23,6 +23,8 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + post_training: + - inline::huggingface tool_runtime: - remote::brave-search - remote::tavily-search diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index d72d299ec..0b4f05128 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "post_training": ["inline::huggingface"], "tool_runtime": [ "remote::brave-search", "remote::tavily-search", @@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::faiss", config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) - + posttraining_provider = Provider( + provider_id="huggingface", + provider_type="inline::huggingface", + config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="ollama", @@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "post_training": [posttraining_provider], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "post_training": [posttraining_provider], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 651d58117..74d0822ca 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -5,6 +5,7 @@ apis: - datasetio - eval - inference +- post_training - safety - scoring - telemetry @@ -80,6 +81,13 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search diff 
--git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 1372486fe..71229be97 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -5,6 +5,7 @@ apis: - datasetio - eval - inference +- post_training - safety - scoring - telemetry @@ -78,6 +79,13 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu tool_runtime: - provider_id: brave-search provider_type: remote::brave-search diff --git a/pyproject.toml b/pyproject.toml index ba7c2300a..1fe64f350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest", + "pytest-timeout", "pytest-asyncio", "pytest-cov", "pytest-html", diff --git a/requirements.txt b/requirements.txt index 0857a9886..c2571025a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,206 +1,69 @@ # This file was autogenerated by uv via the following command: # uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt annotated-types==0.7.0 - # via pydantic anyio==4.8.0 - # via - # httpx - # llama-stack-client - # openai attrs==25.1.0 - # via - # jsonschema - # referencing blobfile==3.0.0 - # via llama-stack cachetools==5.5.2 - # via google-auth certifi==2025.1.31 - # via - # httpcore - # httpx - # kubernetes - # requests charset-normalizer==3.4.1 - # via requests click==8.1.8 - # via llama-stack-client colorama==0.4.6 ; sys_platform == 'win32' - # via - # click - # tqdm distro==1.9.0 - # via - # llama-stack-client - # openai durationpy==0.9 - # via kubernetes exceptiongroup==1.2.2 ; python_full_version < '3.11' - # via anyio filelock==3.17.0 - # via - # blobfile - # huggingface-hub fire==0.7.0 - # via llama-stack fsspec==2024.12.0 - # via huggingface-hub google-auth==2.38.0 - # via kubernetes h11==0.16.0 - # via - # httpcore - # llama-stack httpcore==1.0.9 - # via httpx httpx==0.28.1 - # via - # llama-stack - # llama-stack-client - # openai huggingface-hub==0.29.0 - # via llama-stack idna==3.10 - # via - # anyio - # httpx - # requests jinja2==3.1.6 - # via llama-stack jiter==0.8.2 - # via openai jsonschema==4.23.0 - # via llama-stack jsonschema-specifications==2024.10.1 - # via jsonschema kubernetes==32.0.1 - # via llama-stack llama-stack-client==0.2.7 - # via llama-stack lxml==5.3.1 - # via blobfile markdown-it-py==3.0.0 - # via rich markupsafe==3.0.2 - # via jinja2 mdurl==0.1.2 - # via markdown-it-py numpy==2.2.3 - # via pandas oauthlib==3.2.2 - # via - # kubernetes - # requests-oauthlib openai==1.71.0 - # via llama-stack packaging==24.2 - # via huggingface-hub pandas==2.2.3 - # via llama-stack-client pillow==11.1.0 - # via llama-stack prompt-toolkit==3.0.50 - # via - # llama-stack - # llama-stack-client pyaml==25.1.0 - # via llama-stack-client pyasn1==0.6.1 - # via - # pyasn1-modules - # rsa pyasn1-modules==0.4.2 - # via google-auth pycryptodomex==3.21.0 - # via blobfile pydantic==2.10.6 - # via - # llama-stack - # llama-stack-client - # openai pydantic-core==2.27.2 - # via pydantic pygments==2.19.1 - # via rich python-dateutil==2.9.0.post0 - # via - # kubernetes - # pandas python-dotenv==1.0.1 - # via llama-stack pytz==2025.1 - # via pandas pyyaml==6.0.2 - # via - # huggingface-hub - # kubernetes - # pyaml referencing==0.36.2 - # via - # jsonschema - # jsonschema-specifications 
regex==2024.11.6 - # via tiktoken requests==2.32.3 - # via - # huggingface-hub - # kubernetes - # llama-stack - # requests-oauthlib - # tiktoken requests-oauthlib==2.0.0 - # via kubernetes rich==13.9.4 - # via - # llama-stack - # llama-stack-client rpds-py==0.22.3 - # via - # jsonschema - # referencing rsa==4.9 - # via google-auth setuptools==75.8.0 - # via llama-stack six==1.17.0 - # via - # kubernetes - # python-dateutil sniffio==1.3.1 - # via - # anyio - # llama-stack-client - # openai termcolor==2.5.0 - # via - # fire - # llama-stack - # llama-stack-client tiktoken==0.9.0 - # via llama-stack tqdm==4.67.1 - # via - # huggingface-hub - # llama-stack-client - # openai typing-extensions==4.12.2 - # via - # anyio - # huggingface-hub - # llama-stack-client - # openai - # pydantic - # pydantic-core - # referencing - # rich tzdata==2025.1 - # via pandas urllib3==2.3.0 - # via - # blobfile - # kubernetes - # requests wcwidth==0.2.13 - # via prompt-toolkit websocket-client==1.8.0 - # via kubernetes diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index 648ace9d6..bb4639d17 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -4,20 +4,38 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging +import sys +import time +import uuid + import pytest -from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.post_training import ( - Checkpoint, DataConfig, LoraFinetuningConfig, - OptimizerConfig, - PostTrainingJob, - PostTrainingJobArtifactsResponse, - PostTrainingJobStatusResponse, TrainingConfig, ) +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + + +@pytest.fixture(autouse=True) +def capture_output(capsys): + """Fixture to capture and display output during test execution.""" + yield + captured = capsys.readouterr() + if captured.out: + print("\nCaptured stdout:", captured.out) + if captured.err: + print("\nCaptured stderr:", captured.err) + + +# Force flush stdout to see prints immediately +sys.stdout.reconfigure(line_buffering=True) + # How to run this test: # # pytest llama_stack/providers/tests/post_training/test_post_training.py @@ -25,10 +43,31 @@ from llama_stack.apis.post_training import ( # -v -s --tb=short --disable-warnings -@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API") class TestPostTraining: - @pytest.mark.asyncio - async def test_supervised_fine_tune(self, post_training_stack): + @pytest.mark.integration + @pytest.mark.parametrize( + "purpose, source", + [ + ( + "post-training/messages", + { + "type": "uri", + "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + }, + ), + ], + ) + @pytest.mark.timeout(360) # 6 minutes timeout + def test_supervised_fine_tune(self, llama_stack_client, purpose, source): + logger.info("Starting supervised fine-tuning test") + + # register dataset to train + dataset = llama_stack_client.datasets.register( + purpose=purpose, + source=source, + ) + logger.info(f"Registered dataset with ID: {dataset.identifier}") + algorithm_config = LoraFinetuningConfig( type="LoRA", lora_attn_modules=["q_proj", "v_proj", "output_proj"], @@ -39,62 +78,74 @@ class TestPostTraining: ) data_config = DataConfig( - dataset_id="alpaca", + 
dataset_id=dataset.identifier, batch_size=1, shuffle=False, + data_format="instruct", ) - optimizer_config = OptimizerConfig( - optimizer_type="adamw", - lr=3e-4, - lr_min=3e-5, - weight_decay=0.1, - num_warmup_steps=100, - ) - + # setup training config with minimal settings training_config = TrainingConfig( n_epochs=1, data_config=data_config, - optimizer_config=optimizer_config, max_steps_per_epoch=1, gradient_accumulation_steps=1, ) - post_training_impl = post_training_stack - response = await post_training_impl.supervised_fine_tune( - job_uuid="1234", - model="Llama3.2-3B-Instruct", + + job_uuid = f"test-job{uuid.uuid4()}" + logger.info(f"Starting training job with UUID: {job_uuid}") + + # train with HF trl SFTTrainer as the default + _ = llama_stack_client.post_training.supervised_fine_tune( + job_uuid=job_uuid, + model="ibm-granite/granite-3.3-2b-instruct", algorithm_config=algorithm_config, training_config=training_config, hyperparam_search_config={}, logger_config={}, - checkpoint_dir="null", + checkpoint_dir=None, ) - assert isinstance(response, PostTrainingJob) - assert response.job_uuid == "1234" - @pytest.mark.asyncio - async def test_get_training_jobs(self, post_training_stack): - post_training_impl = post_training_stack - jobs_list = await post_training_impl.get_training_jobs() - assert isinstance(jobs_list, list) - assert jobs_list[0].job_uuid == "1234" + while True: + status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) + if not status: + logger.error("Job not found") + break - @pytest.mark.asyncio - async def test_get_training_job_status(self, post_training_stack): - post_training_impl = post_training_stack - job_status = await post_training_impl.get_training_job_status("1234") - assert isinstance(job_status, PostTrainingJobStatusResponse) - assert job_status.job_uuid == "1234" - assert job_status.status == JobStatus.completed - assert isinstance(job_status.checkpoints[0], Checkpoint) + logger.info(f"Current status: {status}") + if status.status == "completed": + break - @pytest.mark.asyncio - async def test_get_training_job_artifacts(self, post_training_stack): - post_training_impl = post_training_stack - job_artifacts = await post_training_impl.get_training_job_artifacts("1234") - assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) - assert job_artifacts.job_uuid == "1234" - assert isinstance(job_artifacts.checkpoints[0], Checkpoint) - assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0" - assert job_artifacts.checkpoints[0].epoch == 0 - assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path + logger.info("Waiting for job to complete...") + time.sleep(10) # Increased sleep time to reduce polling frequency + + artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) + logger.info(f"Job artifacts: {artifacts}") + + # TODO: Fix these tests to properly represent the Jobs API in training + # @pytest.mark.asyncio + # async def test_get_training_jobs(self, post_training_stack): + # post_training_impl = post_training_stack + # jobs_list = await post_training_impl.get_training_jobs() + # assert isinstance(jobs_list, list) + # assert jobs_list[0].job_uuid == "1234" + + # @pytest.mark.asyncio + # async def test_get_training_job_status(self, post_training_stack): + # post_training_impl = post_training_stack + # job_status = await post_training_impl.get_training_job_status("1234") + # assert isinstance(job_status, PostTrainingJobStatusResponse) + # assert 
job_status.job_uuid == "1234" + # assert job_status.status == JobStatus.completed + # assert isinstance(job_status.checkpoints[0], Checkpoint) + + # @pytest.mark.asyncio + # async def test_get_training_job_artifacts(self, post_training_stack): + # post_training_impl = post_training_stack + # job_artifacts = await post_training_impl.get_training_job_artifacts("1234") + # assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) + # assert job_artifacts.job_uuid == "1234" + # assert isinstance(job_artifacts.checkpoints[0], Checkpoint) + # assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab" + # assert job_artifacts.checkpoints[0].epoch == 0 + # assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path diff --git a/uv.lock b/uv.lock index dbf0c891f..6bd3f84d5 100644 --- a/uv.lock +++ b/uv.lock @@ -1459,6 +1459,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-html" }, { name = "pytest-json-report" }, + { name = "pytest-timeout" }, { name = "ruamel-yaml" }, { name = "ruff" }, { name = "types-requests" }, @@ -1557,6 +1558,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-html", marker = "extra == 'dev'" }, { name = "pytest-json-report", marker = "extra == 'dev'" }, + { name = "pytest-timeout", marker = "extra == 'dev'" }, { name = "python-dotenv" }, { name = "qdrant-client", marker = "extra == 'unit'" }, { name = "requests" }, @@ -2852,6 +2854,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"