llama-stack-mirror/docs/docs/advanced_apis/post_training.mdx
Alexey Rybak 6101c8e015
docs: fix broken links (#3540)
2025-09-24 14:16:31 -07:00


# Post-Training
Post-training in Llama Stack allows you to fine-tune models using various providers and frameworks. This section covers all available post-training providers and how to use them effectively.
## Overview
Llama Stack provides multiple post-training providers:
- **HuggingFace SFTTrainer** (`inline::huggingface`) - Fine-tuning using HuggingFace ecosystem
- **TorchTune** (`inline::torchtune`) - Fine-tuning using Meta's TorchTune framework
- **NVIDIA** (`remote::nvidia`) - Fine-tuning using NVIDIA's platform
## HuggingFace SFTTrainer
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post-training provider for Llama Stack. It lets you run supervised fine-tuning (SFT) on a variety of models and datasets.
### Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (macOS Metal Performance Shaders)
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `str` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `str` | No | | |
| `model_specific_config` | `dict` | No | `{'trust_remote_code': True, 'attn_implementation': 'sdpa'}` | |
| `max_seq_length` | `int` | No | 2048 | |
| `gradient_checkpointing` | `bool` | No | False | |
| `save_total_limit` | `int` | No | 3 | |
| `logging_steps` | `int` | No | 10 | |
| `warmup_ratio` | `float` | No | 0.1 | |
| `weight_decay` | `float` | No | 0.01 | |
| `dataloader_num_workers` | `int` | No | 4 | |
| `dataloader_pin_memory` | `bool` | No | True | |
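`warmup_ratio` is a fraction of the total number of optimizer steps, not a step count. The sketch below estimates both numbers for a single-device run, assuming the provider forwards these values to HuggingFace `TrainingArguments`, which derives warmup as `ceil(total_steps * warmup_ratio)`. The helper `estimate_steps` and the 1,000-example dataset size are hypothetical; `batch_size`, `gradient_accumulation_steps`, and `n_epochs` correspond to the `training_config` fields used in the usage example.

```python
import math


def estimate_steps(
    num_examples: int,
    batch_size: int,
    gradient_accumulation_steps: int,
    n_epochs: int,
    warmup_ratio: float = 0.1,
) -> tuple[int, int]:
    """Hypothetical helper: estimate (total optimizer steps, warmup steps).

    Assumes a single device, no max-steps override, and that a final
    partial batch still counts as a dataloader batch.
    """
    # Micro-batches the dataloader yields per epoch
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One optimizer step every `gradient_accumulation_steps` micro-batches
    steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    total_steps = steps_per_epoch * n_epochs
    # Warmup length derived from the ratio, rounded up
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return total_steps, warmup_steps


# Hypothetical 1,000-example dataset, batch size 32, 4 epochs
print(estimate_steps(1000, batch_size=32, gradient_accumulation_steps=1, n_epochs=4))  # (128, 13)
```

Raising `gradient_accumulation_steps` reduces the number of optimizer steps, and therefore the number of warmup steps, for the same dataset.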
### Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```
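The `device` field has to name hardware that actually exists on the machine running the stack server. A minimal sketch for choosing the value with PyTorch follows; the helper name `pick_device` is hypothetical, and the sketch assumes the provider accepts the standard PyTorch device strings `cuda`, `mps`, and `cpu`. Run it on the server host, not on the client.

```python
def pick_device() -> str:
    """Hypothetical helper: suggest a `device` value for the provider config.

    Prefers CUDA, then Apple MPS, and falls back to CPU, including when
    PyTorch is not installed in the current environment.
    """
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is only present in newer PyTorch releases
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(f"device: {pick_device()}")
```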
### Setup
You can access the HuggingFace trainer via the `starter` distribution:
```bash
llama stack build --distro starter --image-type venv
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
```
### Usage Example
```python
import time
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    algorithm_config_param,
    post_training_supervised_fine_tune_params,
)

# Connect to a locally running Llama Stack server
client = LlamaStackClient(base_url="http://localhost:8321")

# Example dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

# LoRA settings (rank-1 adapters)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)

# Poll until the job reaches a terminal state (a failed job would otherwise loop forever)
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break
    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break
    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job finished in", end_time - start_time, "seconds")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
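The polling loop above can be factored into a reusable function that also gives up after a deadline. This is a minimal sketch: `wait_for_job` and `TERMINAL_STATUSES` are hypothetical names, and the set of terminal status strings is an assumption based on the values the example checks.

```python
import time
from typing import Any

# Assumed terminal job states; a job in any other state is still running
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def wait_for_job(client: Any, job_uuid: str, poll_interval: float = 5.0, timeout: float = 3600.0):
    """Hypothetical helper: poll a post-training job until it stops running.

    Returns the last status object, or None if the job is not found.
    Raises TimeoutError if the job is still running after `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        status = client.post_training.job.status(job_uuid=job_uuid)
        if not status or status.status in TERMINAL_STATUSES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job {job_uuid} still {status.status!r} after {timeout}s")
        time.sleep(poll_interval)
```

After submitting a job, call `wait_for_job(client, job_uuid)` and fetch artifacts only when the returned status is `completed`.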
## TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post-training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
### Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, with training on a single device
- Support for LoRA
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
### Sample Configuration
```yaml
checkpoint_format: meta
```
### Setup
You can access the TorchTune trainer by adding it as a `post_training` provider in your own run configuration YAML:
```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```
You can then build and run your own stack with this provider.
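Once the stack is running, you can confirm which post-training providers it loaded. A minimal sketch, assuming `client.providers.list()` in `llama_stack_client` returns objects with `api`, `provider_id`, and `provider_type` attributes; the helper `post_training_providers` is hypothetical.

```python
from typing import Any


def post_training_providers(client: Any) -> list[tuple[str, str]]:
    """Hypothetical helper: (provider_id, provider_type) pairs for the post_training API.

    Assumes each provider entry exposes `api`, `provider_id`, and `provider_type`.
    """
    return [
        (p.provider_id, p.provider_type)
        for p in client.providers.list()
        if p.api == "post_training"
    ]
```

With the configuration above, calling it with `LlamaStackClient(base_url="http://localhost:8321")` should list `("torchtune", "inline::torchtune")`.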
### Usage Example
```python
import time
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    algorithm_config_param,
    post_training_supervised_fine_tune_params,
)

# Connect to a locally running Llama Stack server
client = LlamaStackClient(base_url="http://localhost:8321")

# Example dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

# LoRA settings (rank-1 adapters)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)

# Poll until the job reaches a terminal state (a failed job would otherwise loop forever)
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break
    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break
    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job finished in", end_time - start_time, "seconds")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
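Both usage examples use a deliberately small LoRA configuration (`rank=1`, `alpha=1`). In the standard LoRA formulation, each adapted linear layer learns two low-rank matrices A (`rank x in_features`) and B (`out_features x rank`), and the update B·A is scaled by `alpha / rank`. The sketch below counts only the `q_proj` adapters for Llama-2-7B (32 layers, hidden size 4096); the examples also set `apply_lora_to_mlp=True`, which adds further adapters not counted here. Whether a given provider applies exactly this scaling is an assumption.

```python
def lora_trainable_params(in_features: int, out_features: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A plus B."""
    return rank * (in_features + out_features)


def lora_scale(alpha: float, rank: int) -> float:
    """Factor applied to the low-rank update B @ A in standard LoRA."""
    return alpha / rank


# Llama-2-7B: 32 decoder layers, q_proj maps 4096 -> 4096
hidden, layers = 4096, 32
q_only = layers * lora_trainable_params(hidden, hidden, rank=1)
frozen_q = layers * hidden * hidden

print(q_only)  # 262144
print(f"{q_only / frozen_q:.4%} of the frozen q_proj weights")
print(lora_scale(alpha=1, rank=1))  # 1.0
```

Increasing `rank` grows the adapter linearly, while `alpha` only changes how strongly the learned update is applied.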
## NVIDIA
The NVIDIA provider (`remote::nvidia`) is a remote post-training provider that runs fine-tuning jobs on NVIDIA's platform through the NeMo Customizer API.
### Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No | | Base URL for the NeMo Customizer API. |
| `timeout` | `int` | No | 300 | Timeout for the NVIDIA Post Training API. |
| `max_retries` | `int` | No | 3 | Maximum number of retries for the NVIDIA Post Training API. |
| `output_model_dir` | `str` | No | test-example-model@v1 | Directory to save the output model. |
### Sample Configuration
```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```
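The sample configuration uses `${env.NAME:=default}` placeholders: each value is read from the environment variable `NAME`, falling back to the text after `:=` (empty for `NVIDIA_API_KEY`). The sketch below approximates that substitution to show which values the provider would see. It is not Llama Stack's own resolver; it assumes an unset or empty variable falls back to the default, and the helper name `resolve_env_defaults` is hypothetical.

```python
import os
import re
from typing import Mapping, Optional

# Matches ${env.NAME:=default}; the default may be empty
_ENV_DEFAULT = re.compile(r"\$\{env\.([A-Za-z0-9_]+):=([^}]*)\}")


def resolve_env_defaults(text: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper: approximate ${env.NAME:=default} substitution.

    Uses the variable's value when it is set and non-empty, otherwise the default.
    """
    env = os.environ if env is None else env

    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return env.get(name) or default

    return _ENV_DEFAULT.sub(_sub, text)


sample = "project_id: ${env.NVIDIA_PROJECT_ID:=test-project}"
print(resolve_env_defaults(sample, env={}))  # project_id: test-project
print(resolve_env_defaults(sample, env={"NVIDIA_PROJECT_ID": "my-project"}))  # project_id: my-project
```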
## Best Practices
- **Choose the right provider**: Use HuggingFace for broad model compatibility, TorchTune for Meta models, or NVIDIA if you already work within NVIDIA's NeMo ecosystem
- **Configure hardware appropriately**: Ensure your configuration matches your available hardware (CPU, GPU, MPS)
- **Monitor jobs**: Poll the job status until the job reaches a terminal state (for example `completed` or `failed`), and fetch artifacts only after it has completed
- **Use appropriate datasets**: Ensure your dataset format matches the expected input format for your chosen provider
## Next Steps
- Check out the [Building Applications - Fine-tuning](../building_applications/index.mdx) guide for application-level examples
- See the [Providers](../providers/post_training/index.mdx) section for detailed provider documentation
- Review the [API Reference](../advanced_apis/post_training.mdx) for complete API documentation