llama-stack-mirror/docs/docs/advanced_apis/post_training.mdx
ehhuang a3f5072776
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Installer CI / lint (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Installer CI / smoke-test-on-dev (push) Failing after 2s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 2s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 1s
API Conformance Tests / check-schema-compatibility (push) Successful in 10s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m18s
chore!: remove --env from llama stack run (#3711)
# What does this PR do?
user can simply set env vars in the beginning of the command.`FOO=BAR
llama stack run ...`

## Test Plan
Run
TELEMETRY_SINKS=coneol uv run --with llama-stack llama stack build
--distro=starter --image-type=venv --run




---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/llamastack/llama-stack/pull/3711).
* #3714
* __->__ #3711
2025-10-07 20:58:15 -07:00

305 lines
8.8 KiB
Text

# Post-Training
Post-training in Llama Stack allows you to fine-tune models using various providers and frameworks. This section covers all available post-training providers and how to use them effectively.
## Overview
Llama Stack provides multiple post-training providers:
- **HuggingFace SFTTrainer** (`inline::huggingface`) - Fine-tuning using HuggingFace ecosystem
- **TorchTune** (`inline::torchtune`) - Fine-tuning using Meta's TorchTune framework
- **NVIDIA** (`remote::nvidia`) - Fine-tuning using NVIDIA's platform
## HuggingFace SFTTrainer
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets.
### Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `str` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `str` | No | |
| `model_specific_config` | `dict` | No | `{'trust_remote_code': True, 'attn_implementation': 'sdpa'}` | |
| `max_seq_length` | `int` | No | 2048 | |
| `gradient_checkpointing` | `bool` | No | False | |
| `save_total_limit` | `int` | No | 3 | |
| `logging_steps` | `int` | No | 10 | |
| `warmup_ratio` | `float` | No | 0.1 | |
| `weight_decay` | `float` | No | 0.01 | |
| `dataloader_num_workers` | `int` | No | 4 | |
| `dataloader_pin_memory` | `bool` | No | True | |
### Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```
### Setup
You can access the HuggingFace trainer via the `starter` distribution:
```bash
llama stack build --distro starter --image-type venv
llama stack run ~/.llama/distributions/starter/starter-run.yaml
```
### Usage Example
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
### Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single device capabilities
- Support for LoRA
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
### Sample Configuration
```yaml
checkpoint_format: meta
```
### Setup
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
```yaml
post_training:
- provider_id: torchtune
provider_type: inline::torchtune
config: {}
```
You can then build and run your own stack with this provider.
### Usage Example
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "meta-llama/Llama-2-7b-hf"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## NVIDIA
NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No | | Base URL for the NeMo Customizer API |
| `timeout` | `int` | No | 300 | Timeout for the NVIDIA Post Training API |
| `max_retries` | `int` | No | 3 | Maximum number of retries for the NVIDIA Post Training API |
| `output_model_dir` | `str` | No | test-example-model@v1 | Directory to save the output model |
### Sample Configuration
```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```
## Best Practices
- **Choose the right provider**: Use HuggingFace for broader compatibility, TorchTune for Meta models, or NVIDIA for their ecosystem
- **Configure hardware appropriately**: Ensure your configuration matches your available hardware (CPU, GPU, MPS)
- **Monitor jobs**: Always monitor job status and handle completion appropriately
- **Use appropriate datasets**: Ensure your dataset format matches the expected input format for your chosen provider
## Next Steps
- Check out the [Building Applications - Fine-tuning](../building_applications/index.mdx) guide for application-level examples
- See the [Providers](../providers/post_training/index.mdx) section for detailed provider documentation
- Review the [API Reference](../advanced_apis/post_training.mdx) for complete API documentation