feat: add huggingface post_training impl (#2132)

# What does this PR do?


Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a very
popular option for running training jobs. The provider config lets a user set
key fields such as the device, chat template, and checkpoint format.
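
For illustration, here is a minimal sketch of overriding a few of those fields in Python (the field names come from the new `HuggingFacePostTrainingConfig` added in this PR; the values are just an example):

```python
from llama_stack.providers.inline.post_training.huggingface import (
    HuggingFacePostTrainingConfig,
)

# Example only: override a few of the defaults defined in config.py.
config = HuggingFacePostTrainingConfig(
    device="mps",  # "cuda", "cpu", or "mps"
    checkpoint_format="huggingface",  # or "full_state"
    chat_template="<|user|>\n{input}\n<|assistant|>\n{output}",
    max_seq_length=2048,
)
```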

The provider ships with one recipe, `finetune_single_device`, which works both
with and without LoRA.
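
As a sketch (field names based on the existing `LoraFinetuningConfig` API used in the updated integration test), enabling LoRA is just a matter of passing an algorithm config; passing `algorithm_config=None` runs plain SFT instead:

```python
from llama_stack.apis.post_training import LoraFinetuningConfig

# Example values; the recipe reads rank, alpha, and lora_attn_modules from this config.
algorithm_config = LoraFinetuningConfig(
    type="LoRA",
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
```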

Any model with a valid HF identifier can be given, and the model will be
pulled automatically.

This has been tested so far with the CPU and MPS device types, but it should be
compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes
the steps per epoch, steps per save, and steps per eval, sets a sane
`SFTConfig`, and runs `n_epochs` of training.
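
Concretely, the step bookkeeping works roughly like this (a simplified sketch of the arithmetic in `setup_training_args`, with example numbers):

```python
# Example numbers; the real values come from the processed dataset and TrainingConfig.
train_rows = 900            # rows remaining after the 90/10 train/validation split
batch_size = 1
n_epochs = 1
max_steps_per_epoch = 1

steps_per_epoch = train_rows // batch_size
total_steps = steps_per_epoch * n_epochs
max_steps = min(max_steps_per_epoch, total_steps)

eval_steps = max(1, steps_per_epoch // 10)     # evaluate ~10 times per epoch
save_steps = max(1, steps_per_epoch // 5)      # save ~5 times per epoch
logging_steps = max(1, steps_per_epoch // 50)  # log ~50 times per epoch
```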

If `checkpoint_dir` is `None`, no model is saved. If a checkpoint dir is
provided, a model is saved every `save_steps` and at the end of training.
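
That behavior is keyed entirely off the optional output directory, roughly:

```python
from pathlib import Path

output_dir_path: Path | None = None  # e.g. Path("~/.llama/checkpoints/my-run")

# Sketch of the save logic: no directory means nothing is written to disk.
if output_dir_path is None:
    save_strategy = "no"
else:
    save_strategy = "steps"  # a checkpoint every `save_steps`
    # After training, the final (LoRA-merged if applicable) model is written
    # to output_dir_path / "merged_model" via save_pretrained().
```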


## Test Plan

Re-enabled the post_training integration test suite with a single test that
loads the simpleqa dataset
(https://huggingface.co/datasets/llamastack/simpleqa) and a tiny Granite
model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The
test now uses the Llama Stack client and the proper post_training API.

The test runs one step with a `batch_size` of 1. It runs on CPU on the
Ubuntu runner, so it needs a small batch and a single step.
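
Condensed, the test flow looks like this (client fixture and the LoRA/training configs are omitted here; see the test file in this diff for the full version):

```python
import time
import uuid

# Register the dataset to train on.
dataset = llama_stack_client.datasets.register(
    purpose="post-training/messages",
    source={"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"},
)

job_uuid = f"test-job{uuid.uuid4()}"
llama_stack_client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    model="ibm-granite/granite-3.3-2b-instruct",
    algorithm_config=algorithm_config,  # LoraFinetuningConfig from the test
    training_config=training_config,    # n_epochs=1, max_steps_per_epoch=1, batch_size=1
    hyperparam_search_config={},
    logger_config={},
    checkpoint_dir=None,
)

# Poll until the job reports completion.
while True:
    status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
    if status and status.status == "completed":
        break
    time.sleep(10)
```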

[//]: # (## Documentation)

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern, 2025-05-16 17:41:28 -04:00, committed by GitHub
parent 8f9964f46b
commit f02f7b28c1
20 changed files with 1181 additions and 201 deletions


@ -58,7 +58,7 @@ jobs:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: | run: |
source .venv/bin/activate source .venv/bin/activate
nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 & LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
- name: Wait for Llama Stack server to be ready - name: Wait for Llama Stack server to be ready
if: matrix.client-type == 'http' if: matrix.client-type == 'http'
@ -85,6 +85,11 @@ jobs:
echo "Ollama health check failed" echo "Ollama health check failed"
exit 1 exit 1
fi fi
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Run Integration Tests - name: Run Integration Tests
env: env:
@ -100,12 +105,19 @@ jobs:
--text-model="meta-llama/Llama-3.2-3B-Instruct" \ --text-model="meta-llama/Llama-3.2-3B-Instruct" \
--embedding-model=all-MiniLM-L6-v2 --embedding-model=all-MiniLM-L6-v2
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Write ollama logs to file - name: Write ollama logs to file
if: ${{ always() }}
run: | run: |
sudo journalctl -u ollama.service > ollama.log sudo journalctl -u ollama.service > ollama.log
- name: Upload all logs to artifacts - name: Upload all logs to artifacts
if: always() if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with: with:
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }} name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}


@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| datasetio | `remote::huggingface`, `inline::localfs` | | datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` | | eval | `inline::meta-reference` |
| inference | `remote::ollama` | | inference | `remote::ollama` |
| post_training | `inline::huggingface` |
| safety | `inline::llama-guard` | | safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |


@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
def evacuate_model_from_device(model, device: str):
"""Safely clear a model from memory and free device resources.
This function handles the proper cleanup of a model by:
1. Moving the model to CPU if it's on a non-CPU device
2. Deleting the model object to free memory
3. Running garbage collection
4. Clearing CUDA cache if the model was on a CUDA device
Args:
model: The PyTorch model to clear
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
Note:
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
- For MPS devices, only moves the model to CPU (no cache clearing available)
- For CPU devices, only deletes the model object and runs garbage collection
"""
if device != "cpu":
model.to("cpu")
del model
gc.collect()
if device == "cuda":
# we need to import such that this is only imported when the method is called
import torch
torch.cuda.empty_cache()
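
For reference (not part of the diff), a minimal usage sketch of this helper, mirroring how the torchtune recipe calls it later in this commit:

```python
# Illustrative only: release a finished model the same way the recipes do.
import torch

from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 8).to(device)  # stand-in for a trained model

evacuate_model_from_device(model, device.type)  # moves to CPU, deletes, clears CUDA cache
```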


@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import HuggingFacePostTrainingConfig
# post_training api and the huggingface provider is still experimental and under heavy development
async def get_provider_impl(
config: HuggingFacePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import HuggingFacePostTrainingImpl
impl = HuggingFacePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl


@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
# Device to run training on (cuda, cpu, mps)
device: str = "cuda"
# Distributed training backend if using multiple devices
# fsdp: Fully Sharded Data Parallel
# deepspeed: DeepSpeed ZeRO optimization
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
# Format for saving model checkpoints
# full_state: Save complete model state
# huggingface: Save in HuggingFace format (recommended for compatibility)
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
# Template for formatting chat inputs and outputs
# Used to structure the conversation format for training
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
# Model-specific configuration parameters
# trust_remote_code: Allow execution of custom model code
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
model_specific_config: dict = {
"trust_remote_code": True,
"attn_implementation": "sdpa",
}
# Maximum sequence length for training
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
# Longer sequences may cause memory issues on MPS devices
max_seq_length: int = 2048
# Enable gradient checkpointing to reduce memory usage
# Trades computation for memory by recomputing activations
gradient_checkpointing: bool = False
# Maximum number of checkpoints to keep
# Older checkpoints are deleted when this limit is reached
save_total_limit: int = 3
# Number of training steps between logging updates
logging_steps: int = 10
# Ratio of training steps used for learning rate warmup
# Helps stabilize early training
warmup_ratio: float = 0.1
# L2 regularization coefficient
# Helps prevent overfitting
weight_decay: float = 0.01
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
dataloader_num_workers: int = 4
# Whether to pin memory in data loader
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}


@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class HuggingFacePostTrainingImpl:
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))


@ -0,0 +1,683 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import psutil
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
TrainingConfig,
)
from ..config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
"""Validate that the dataset has the required fields."""
required_fields = ["input_query", "expected_answer", "chat_completion_input"]
return all(field in row for row in rows for field in required_fields)
def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in instruct format."""
if "chat_completion_input" in row and "expected_answer" in row:
try:
messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None
if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}")
return None, None
return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None
return None, None
def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row in dialog format."""
if "dialog" in row:
try:
dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None
if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}")
return None, None
if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message")
return None, None
# Convert to human/gpt format
role_map = {"user": "human", "assistant": "gpt"}
conversations = []
for msg in dialog:
if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}")
continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation
return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None
return None, None
def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
"""Process a row using fallback formats."""
if "input" in row and "output" in row:
return row["input"], row["output"]
elif "prompt" in row and "completion" in row:
return row["prompt"], row["completion"]
elif "question" in row and "answer" in row:
return row["question"], row["answer"]
return None, None
def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format input and output text based on model requirements."""
if hasattr(provider_config, "chat_template"):
return provider_config.chat_template.format(input=input_text, output=output_text)
return f"{input_text}\n{output_text}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset."""
formatted_rows = []
for row in rows:
input_text = None
output_text = None
# Process based on format
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
if config.data_config.data_format.value == "instruct":
input_text, output_text = self._process_instruct_format(row)
elif config.data_config.data_format.value == "dialog":
input_text, output_text = self._process_dialog_format(row)
else:
input_text, output_text = self._process_fallback_format(row)
if input_text and output_text:
formatted_text = self._format_text(input_text, output_text, provider_config)
formatted_rows.append({"text": formatted_text})
if not formatted_rows:
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
raise ValueError(
f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
)
return Dataset.from_list(formatted_rows)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer."""
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding=True,
truncation=True,
max_length=provider_config.max_seq_length,
return_tensors=None,
)
return ds.map(
tokenize_function,
batched=True,
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running training process.
This method serves as a bridge between the multiprocessing Process and the async training function.
It creates a new event loop to run the async training process.
Args:
model: The model identifier to load
dataset_id: ID of the dataset to use for training
provider_config: Configuration specific to the HuggingFace provider
peft_config: Optional LoRA configuration
config: General training configuration
output_dir_path: Optional path to save the model
"""
import asyncio
logger.info("Starting training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
peft_config=peft_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for training.
Args:
model: The model identifier to load
config: Training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (train_dataset, eval_dataset, tokenizer)
"""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
# This is common for models that don't have a dedicated pad token
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to right for causal language modeling
# This ensures that padding tokens don't interfere with the model's ability
# to predict the next token in the sequence
tokenizer.padding_side = "right"
# Set truncation side to right to keep the beginning of the sequence
# This is important for maintaining context and instruction format
tokenizer.truncation_side = "right"
# Set model max length to match provider config
# This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> SFTConfig:
"""Setup training arguments.
Args:
config: Training configuration
provider_config: Provider-specific configuration
device: The device to train on
output_dir_path: Optional path to save the model
steps_per_epoch: Number of steps per epoch
Returns:
Configured SFTConfig object
"""
logger.info("Configuring training arguments")
lr = 2e-5
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Eval steps: {eval_steps}")
logger.info(f"- Save steps: {save_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
if output_dir_path:
save_strategy = "steps"
logger.info(f"Will save checkpoints to {output_dir_path}")
return SFTConfig(
max_steps=max_steps,
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy="steps",
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_seq_length=provider_config.max_seq_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
dataset_text_field="text",
packing=False,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
eval_steps=eval_steps,
save_steps=save_steps,
logging_steps=logging_steps,
)
def save_model(
self,
model_obj: AutoModelForCausalLM,
trainer: SFTTrainer,
peft_config: LoraConfig | None,
output_dir_path: Path,
) -> None:
"""Save the trained model.
Args:
model_obj: The model to save
trainer: The trainer instance
peft_config: Optional LoRA configuration
output_dir_path: Path to save the model
"""
logger.info("Saving final model")
model_obj.config.use_cache = True
if peft_config:
logger.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload()
else:
model_obj = trainer.model
save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path)
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
peft_config: LoraConfig | None,
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
trainer = SFTTrainer(
model=model_obj,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
args=training_args,
)
try:
# Train
logger.info("Starting training")
trainer.train()
logger.info("Training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
self.save_model(model_obj, trainer, peft_config, output_dir_path)
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
del trainer
gc.collect()
logger.info("Cleanup completed")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
lora_config: LoraFinetuningConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Configure LoRA
peft_config = None
if lora_config:
peft_config = LoraConfig(
lora_alpha=lora_config.alpha,
lora_dropout=0.1,
r=lora_config.rank,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_config.lora_attn_modules,
)
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"peft_config": peft_config,
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = None
if output_dir_path:
# Create checkpoint
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(output_dir_path / "merged_model"),
)
checkpoints = [checkpoint]
return memory_stats, checkpoints
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()


@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import gc
import logging import logging
import os import os
import time import time
@ -47,6 +46,7 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer, TorchtuneCheckpointer,
@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice:
checkpoints.append(checkpoint) checkpoints.append(checkpoint)
# clean up the memory after training finishes # clean up the memory after training finishes
if self._device.type != "cpu": evacuate_model_from_device(self._model, self._device.type)
self._model.to("cpu")
torch.cuda.empty_cache()
del self._model
gc.collect()
return (memory_stats, checkpoints) return (memory_stats, checkpoints)


@ -21,6 +21,17 @@ def available_providers() -> list[ProviderSpec]:
Api.datasets, Api.datasets,
], ],
), ),
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
remote_provider_spec( remote_provider_spec(
api=Api.post_training, api=Api.post_training,
adapter=AdapterSpec( adapter=AdapterSpec(


@ -441,6 +441,7 @@
"opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk", "opentelemetry-sdk",
"pandas", "pandas",
"peft",
"pillow", "pillow",
"psycopg2-binary", "psycopg2-binary",
"pymongo", "pymongo",
@ -451,9 +452,11 @@
"scikit-learn", "scikit-learn",
"scipy", "scipy",
"sentencepiece", "sentencepiece",
"torch",
"tqdm", "tqdm",
"transformers", "transformers",
"tree_sitter", "tree_sitter",
"trl",
"uvicorn" "uvicorn"
], ],
"open-benchmark": [ "open-benchmark": [


@ -13,9 +13,10 @@ distribution_spec:
- inline::basic - inline::basic
- inline::braintrust - inline::braintrust
post_training: post_training:
- inline::torchtune - inline::huggingface
datasetio: datasetio:
- inline::localfs - inline::localfs
- remote::huggingface
telemetry: telemetry:
- inline::meta-reference - inline::meta-reference
agents: agents:


@ -49,16 +49,24 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db
telemetry: telemetry:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: {} config: {}
post_training: post_training:
- provider_id: torchtune-post-training - provider_id: huggingface
provider_type: inline::torchtune provider_type: inline::huggingface
config: { config:
checkpoint_format: huggingface checkpoint_format: huggingface
} distributed_backend: null
device: cpu
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference


@ -23,6 +23,8 @@ distribution_spec:
- inline::basic - inline::basic
- inline::llm-as-judge - inline::llm-as-judge
- inline::braintrust - inline::braintrust
post_training:
- inline::huggingface
tool_runtime: tool_runtime:
- remote::brave-search - remote::brave-search
- remote::tavily-search - remote::tavily-search


@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
ShieldInput, ShieldInput,
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"], "eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"], "datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"post_training": ["inline::huggingface"],
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", "remote::brave-search",
"remote::tavily-search", "remote::tavily-search",
@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="inline::faiss", provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
) )
posttraining_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", model_id="${env.INFERENCE_MODEL}",
provider_id="ollama", provider_id="ollama",
@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={ provider_overrides={
"inference": [inference_provider], "inference": [inference_provider],
"vector_io": [vector_io_provider_faiss], "vector_io": [vector_io_provider_faiss],
"post_training": [posttraining_provider],
}, },
default_models=[inference_model, embedding_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={ provider_overrides={
"inference": [inference_provider], "inference": [inference_provider],
"vector_io": [vector_io_provider_faiss], "vector_io": [vector_io_provider_faiss],
"post_training": [posttraining_provider],
"safety": [ "safety": [
Provider( Provider(
provider_id="llama-guard", provider_id="llama-guard",


@ -5,6 +5,7 @@ apis:
- datasetio - datasetio
- eval - eval
- inference - inference
- post_training
- safety - safety
- scoring - scoring
- telemetry - telemetry
@ -80,6 +81,13 @@ providers:
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
openai_api_key: ${env.OPENAI_API_KEY:} openai_api_key: ${env.OPENAI_API_KEY:}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_id: brave-search
provider_type: remote::brave-search provider_type: remote::brave-search


@ -5,6 +5,7 @@ apis:
- datasetio - datasetio
- eval - eval
- inference - inference
- post_training
- safety - safety
- scoring - scoring
- telemetry - telemetry
@ -78,6 +79,13 @@ providers:
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
openai_api_key: ${env.OPENAI_API_KEY:} openai_api_key: ${env.OPENAI_API_KEY:}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_id: brave-search
provider_type: remote::brave-search provider_type: remote::brave-search


@ -45,6 +45,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"pytest", "pytest",
"pytest-timeout",
"pytest-asyncio", "pytest-asyncio",
"pytest-cov", "pytest-cov",
"pytest-html", "pytest-html",


@ -1,206 +1,69 @@
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt # uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
annotated-types==0.7.0 annotated-types==0.7.0
# via pydantic
anyio==4.8.0 anyio==4.8.0
# via
# httpx
# llama-stack-client
# openai
attrs==25.1.0 attrs==25.1.0
# via
# jsonschema
# referencing
blobfile==3.0.0 blobfile==3.0.0
# via llama-stack
cachetools==5.5.2 cachetools==5.5.2
# via google-auth
certifi==2025.1.31 certifi==2025.1.31
# via
# httpcore
# httpx
# kubernetes
# requests
charset-normalizer==3.4.1 charset-normalizer==3.4.1
# via requests
click==8.1.8 click==8.1.8
# via llama-stack-client
colorama==0.4.6 ; sys_platform == 'win32' colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
distro==1.9.0 distro==1.9.0
# via
# llama-stack-client
# openai
durationpy==0.9 durationpy==0.9
# via kubernetes
exceptiongroup==1.2.2 ; python_full_version < '3.11' exceptiongroup==1.2.2 ; python_full_version < '3.11'
# via anyio
filelock==3.17.0 filelock==3.17.0
# via
# blobfile
# huggingface-hub
fire==0.7.0 fire==0.7.0
# via llama-stack
fsspec==2024.12.0 fsspec==2024.12.0
# via huggingface-hub
google-auth==2.38.0 google-auth==2.38.0
# via kubernetes
h11==0.16.0 h11==0.16.0
# via
# httpcore
# llama-stack
httpcore==1.0.9 httpcore==1.0.9
# via httpx
httpx==0.28.1 httpx==0.28.1
# via
# llama-stack
# llama-stack-client
# openai
huggingface-hub==0.29.0 huggingface-hub==0.29.0
# via llama-stack
idna==3.10 idna==3.10
# via
# anyio
# httpx
# requests
jinja2==3.1.6 jinja2==3.1.6
# via llama-stack
jiter==0.8.2 jiter==0.8.2
# via openai
jsonschema==4.23.0 jsonschema==4.23.0
# via llama-stack
jsonschema-specifications==2024.10.1 jsonschema-specifications==2024.10.1
# via jsonschema
kubernetes==32.0.1 kubernetes==32.0.1
# via llama-stack
llama-stack-client==0.2.7 llama-stack-client==0.2.7
# via llama-stack
lxml==5.3.1 lxml==5.3.1
# via blobfile
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2 markupsafe==3.0.2
# via jinja2
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py
numpy==2.2.3 numpy==2.2.3
# via pandas
oauthlib==3.2.2 oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
openai==1.71.0 openai==1.71.0
# via llama-stack
packaging==24.2 packaging==24.2
# via huggingface-hub
pandas==2.2.3 pandas==2.2.3
# via llama-stack-client
pillow==11.1.0 pillow==11.1.0
# via llama-stack
prompt-toolkit==3.0.50 prompt-toolkit==3.0.50
# via
# llama-stack
# llama-stack-client
pyaml==25.1.0 pyaml==25.1.0
# via llama-stack-client
pyasn1==0.6.1 pyasn1==0.6.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.2 pyasn1-modules==0.4.2
# via google-auth
pycryptodomex==3.21.0 pycryptodomex==3.21.0
# via blobfile
pydantic==2.10.6 pydantic==2.10.6
# via
# llama-stack
# llama-stack-client
# openai
pydantic-core==2.27.2 pydantic-core==2.27.2
# via pydantic
pygments==2.19.1 pygments==2.19.1
# via rich
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
# via
# kubernetes
# pandas
python-dotenv==1.0.1 python-dotenv==1.0.1
# via llama-stack
pytz==2025.1 pytz==2025.1
# via pandas
pyyaml==6.0.2 pyyaml==6.0.2
# via
# huggingface-hub
# kubernetes
# pyaml
referencing==0.36.2 referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6 regex==2024.11.6
# via tiktoken
requests==2.32.3 requests==2.32.3
# via
# huggingface-hub
# kubernetes
# llama-stack
# requests-oauthlib
# tiktoken
requests-oauthlib==2.0.0 requests-oauthlib==2.0.0
# via kubernetes
rich==13.9.4 rich==13.9.4
# via
# llama-stack
# llama-stack-client
rpds-py==0.22.3 rpds-py==0.22.3
# via
# jsonschema
# referencing
rsa==4.9 rsa==4.9
# via google-auth
setuptools==75.8.0 setuptools==75.8.0
# via llama-stack
six==1.17.0 six==1.17.0
# via
# kubernetes
# python-dateutil
sniffio==1.3.1 sniffio==1.3.1
# via
# anyio
# llama-stack-client
# openai
termcolor==2.5.0 termcolor==2.5.0
# via
# fire
# llama-stack
# llama-stack-client
tiktoken==0.9.0 tiktoken==0.9.0
# via llama-stack
tqdm==4.67.1 tqdm==4.67.1
# via
# huggingface-hub
# llama-stack-client
# openai
typing-extensions==4.12.2 typing-extensions==4.12.2
# via
# anyio
# huggingface-hub
# llama-stack-client
# openai
# pydantic
# pydantic-core
# referencing
# rich
tzdata==2025.1 tzdata==2025.1
# via pandas
urllib3==2.3.0 urllib3==2.3.0
# via
# blobfile
# kubernetes
# requests
wcwidth==0.2.13 wcwidth==0.2.13
# via prompt-toolkit
websocket-client==1.8.0 websocket-client==1.8.0
# via kubernetes


@ -4,20 +4,38 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import sys
import time
import uuid
import pytest import pytest
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import ( from llama_stack.apis.post_training import (
Checkpoint,
DataConfig, DataConfig,
LoraFinetuningConfig, LoraFinetuningConfig,
OptimizerConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig, TrainingConfig,
) )
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
@pytest.fixture(autouse=True)
def capture_output(capsys):
"""Fixture to capture and display output during test execution."""
yield
captured = capsys.readouterr()
if captured.out:
print("\nCaptured stdout:", captured.out)
if captured.err:
print("\nCaptured stderr:", captured.err)
# Force flush stdout to see prints immediately
sys.stdout.reconfigure(line_buffering=True)
# How to run this test: # How to run this test:
# #
# pytest llama_stack/providers/tests/post_training/test_post_training.py # pytest llama_stack/providers/tests/post_training/test_post_training.py
@ -25,10 +43,31 @@ from llama_stack.apis.post_training import (
# -v -s --tb=short --disable-warnings # -v -s --tb=short --disable-warnings
@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
class TestPostTraining: class TestPostTraining:
@pytest.mark.asyncio @pytest.mark.integration
async def test_supervised_fine_tune(self, post_training_stack): @pytest.mark.parametrize(
"purpose, source",
[
(
"post-training/messages",
{
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
),
],
)
@pytest.mark.timeout(360) # 6 minutes timeout
def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
logger.info("Starting supervised fine-tuning test")
# register dataset to train
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
logger.info(f"Registered dataset with ID: {dataset.identifier}")
algorithm_config = LoraFinetuningConfig( algorithm_config = LoraFinetuningConfig(
type="LoRA", type="LoRA",
lora_attn_modules=["q_proj", "v_proj", "output_proj"], lora_attn_modules=["q_proj", "v_proj", "output_proj"],
@ -39,62 +78,74 @@ class TestPostTraining:
) )
data_config = DataConfig( data_config = DataConfig(
dataset_id="alpaca", dataset_id=dataset.identifier,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
data_format="instruct",
) )
optimizer_config = OptimizerConfig( # setup training config with minimal settings
optimizer_type="adamw",
lr=3e-4,
lr_min=3e-5,
weight_decay=0.1,
num_warmup_steps=100,
)
training_config = TrainingConfig( training_config = TrainingConfig(
n_epochs=1, n_epochs=1,
data_config=data_config, data_config=data_config,
optimizer_config=optimizer_config,
max_steps_per_epoch=1, max_steps_per_epoch=1,
gradient_accumulation_steps=1, gradient_accumulation_steps=1,
) )
post_training_impl = post_training_stack
response = await post_training_impl.supervised_fine_tune( job_uuid = f"test-job{uuid.uuid4()}"
job_uuid="1234", logger.info(f"Starting training job with UUID: {job_uuid}")
model="Llama3.2-3B-Instruct",
# train with HF trl SFTTrainer as the default
_ = llama_stack_client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
model="ibm-granite/granite-3.3-2b-instruct",
algorithm_config=algorithm_config, algorithm_config=algorithm_config,
training_config=training_config, training_config=training_config,
hyperparam_search_config={}, hyperparam_search_config={},
logger_config={}, logger_config={},
checkpoint_dir="null", checkpoint_dir=None,
) )
assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234"
@pytest.mark.asyncio while True:
async def test_get_training_jobs(self, post_training_stack): status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
post_training_impl = post_training_stack if not status:
jobs_list = await post_training_impl.get_training_jobs() logger.error("Job not found")
assert isinstance(jobs_list, list) break
assert jobs_list[0].job_uuid == "1234"
@pytest.mark.asyncio logger.info(f"Current status: {status}")
async def test_get_training_job_status(self, post_training_stack): if status.status == "completed":
post_training_impl = post_training_stack break
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@pytest.mark.asyncio logger.info("Waiting for job to complete...")
async def test_get_training_job_artifacts(self, post_training_stack): time.sleep(10) # Increased sleep time to reduce polling frequency
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234") artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse) logger.info(f"Job artifacts: {artifacts}")
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint) # TODO: Fix these tests to properly represent the Jobs API in training
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0" # @pytest.mark.asyncio
assert job_artifacts.checkpoints[0].epoch == 0 # async def test_get_training_jobs(self, post_training_stack):
assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path # post_training_impl = post_training_stack
# jobs_list = await post_training_impl.get_training_jobs()
# assert isinstance(jobs_list, list)
# assert jobs_list[0].job_uuid == "1234"
# @pytest.mark.asyncio
# async def test_get_training_job_status(self, post_training_stack):
# post_training_impl = post_training_stack
# job_status = await post_training_impl.get_training_job_status("1234")
# assert isinstance(job_status, PostTrainingJobStatusResponse)
# assert job_status.job_uuid == "1234"
# assert job_status.status == JobStatus.completed
# assert isinstance(job_status.checkpoints[0], Checkpoint)
# @pytest.mark.asyncio
# async def test_get_training_job_artifacts(self, post_training_stack):
# post_training_impl = post_training_stack
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
# assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
# assert job_artifacts.job_uuid == "1234"
# assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
# assert job_artifacts.checkpoints[0].epoch == 0
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path

uv.lock (generated)

@ -1459,6 +1459,7 @@ dev = [
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "pytest-html" }, { name = "pytest-html" },
{ name = "pytest-json-report" }, { name = "pytest-json-report" },
{ name = "pytest-timeout" },
{ name = "ruamel-yaml" }, { name = "ruamel-yaml" },
{ name = "ruff" }, { name = "ruff" },
{ name = "types-requests" }, { name = "types-requests" },
@ -1557,6 +1558,7 @@ requires-dist = [
{ name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'dev'" },
{ name = "pytest-html", marker = "extra == 'dev'" }, { name = "pytest-html", marker = "extra == 'dev'" },
{ name = "pytest-json-report", marker = "extra == 'dev'" }, { name = "pytest-json-report", marker = "extra == 'dev'" },
{ name = "pytest-timeout", marker = "extra == 'dev'" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "qdrant-client", marker = "extra == 'unit'" }, { name = "qdrant-client", marker = "extra == 'unit'" },
{ name = "requests" }, { name = "requests" },
@ -2852,6 +2854,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" },
] ]
[[package]]
name = "pytest-timeout"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" },
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"