# Post-Training

Post-training in Llama Stack allows you to fine-tune models using various providers and frameworks. This section covers all available post-training providers and how to use them effectively.

## Overview

Llama Stack provides multiple post-training providers:

- **HuggingFace SFTTrainer** (`inline::huggingface`) - Fine-tuning using the HuggingFace ecosystem
- **TorchTune** (`inline::torchtune`) - Fine-tuning using Meta's TorchTune framework
- **NVIDIA** (`remote::nvidia`) - Fine-tuning using NVIDIA's platform
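
If you're not sure which of these a running stack actually exposes, you can ask it directly. A minimal sketch, assuming a stack listening on `localhost:8321` and a `llama-stack-client` version that provides `providers.list()`:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Print every provider registered for the post_training API.
for provider in client.providers.list():
    if provider.api == "post_training":
        print(provider.provider_id, "->", provider.provider_type)
```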

## HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post-training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.

### Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU, CPU, and MPS support (macOS Metal Performance Shaders)

### Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `str` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `str` | No | | |
| `model_specific_config` | `dict` | No | `{'trust_remote_code': True, 'attn_implementation': 'sdpa'}` | |
| `max_seq_length` | `int` | No | 2048 | |
| `gradient_checkpointing` | `bool` | No | False | |
| `save_total_limit` | `int` | No | 3 | |
| `logging_steps` | `int` | No | 10 | |
| `warmup_ratio` | `float` | No | 0.1 | |
| `weight_decay` | `float` | No | 0.01 | |
| `dataloader_num_workers` | `int` | No | 4 | |
| `dataloader_pin_memory` | `bool` | No | True | |

### Sample Configuration

```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```
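
The remaining fields from the configuration table can be set the same way. A sketch with purely illustrative values (omitted fields keep the defaults listed above):

```yaml
# Illustrative values only; omit any field to keep its documented default.
checkpoint_format: huggingface
device: cuda
max_seq_length: 4096
gradient_checkpointing: true
save_total_limit: 2
logging_steps: 20
warmup_ratio: 0.05
weight_decay: 0.01
dataloader_num_workers: 8
dataloader_pin_memory: true
```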

### Setup

You can access the HuggingFace trainer via the `starter` distribution:

```bash
llama stack build --distro starter --image-type venv
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
```
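
Once the server is running, a quick smoke test confirms it is reachable (assumes the default port and the `/v1/health` endpoint exposed by recent Llama Stack releases):

```bash
# A healthy stack returns a small JSON payload with a status field.
curl http://localhost:8321/v1/health
```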

### Usage Example

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
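
Jobs don't have to run to completion. The same job resource used for `status` and `artifacts` above also supports listing and cancellation; a sketch, assuming a client version that exposes these methods:

```python
# List all post-training jobs known to the stack (sketch; assumes the
# client exposes post_training.job.list and post_training.job.cancel).
for job in client.post_training.job.list():
    print(job)

# Cancel a running job by its UUID.
client.post_training.job.cancel(job_uuid=job_uuid)
```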

## TorchTune

[TorchTune](https://github.com/pytorch/torchtune) is an inline post-training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.

### Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single-device capabilities
- Support for LoRA

### Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |

### Sample Configuration

```yaml
checkpoint_format: meta
```
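
If you need reproducible runs, `torch_seed` from the table above can be pinned as well (the value here is illustrative):

```yaml
checkpoint_format: meta
torch_seed: 42
```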

### Setup

You can access the TorchTune trainer by writing your own run config that points at the provider:

```yaml
post_training:
- provider_id: torchtune
  provider_type: inline::torchtune
  config: {}
```

You can then build and run your own stack with this provider.
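
For example, once your run config includes the provider block above, you can launch it with the same CLI used in the HuggingFace setup (the path below is illustrative):

```bash
# Point this at your own run config containing the torchtune provider.
llama stack run --image-type venv ./my-torchtune-run.yaml
```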

### Usage Example

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```

## NVIDIA

NVIDIA's post-training provider (`remote::nvidia`) fine-tunes models remotely on NVIDIA's platform via the NeMo Customizer API.

### Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No | | Base URL for the NeMo Customizer API. |
| `timeout` | `int` | No | 300 | Timeout (seconds) for the NVIDIA Post Training API. |
| `max_retries` | `int` | No | 3 | Maximum number of retries for the NVIDIA Post Training API. |
| `output_model_dir` | `str` | No | test-example-model@v1 | Directory to save the output model. |

### Sample Configuration

```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```
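
Since the sample configuration reads its values from environment variables, exporting them before starting the stack is enough (all values below are placeholders):

```bash
# Placeholder values; substitute your own credentials and endpoints.
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_PROJECT_ID="my-project"
export NVIDIA_CUSTOMIZER_URL="https://your-customizer-endpoint.example.com"
```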

## Best Practices

- **Choose the right provider**: Use HuggingFace for the broadest model and dataset compatibility, TorchTune for Meta's Llama models, or NVIDIA when working within their ecosystem
- **Configure hardware appropriately**: Make sure your configuration matches your available hardware (CPU, GPU, or MPS)
- **Monitor jobs**: Always poll job status and handle both completion and failure (see the sketch after this list)
- **Use appropriate datasets**: Ensure your dataset format matches the expected input format for your chosen provider
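
The usage examples above break only on `completed`. In practice you'll also want to stop on failure; a sketch, assuming the job status enum includes a `failed` value alongside `completed`:

```python
import time

# Poll until the job reaches a terminal state (sketch; assumes the status
# enum includes "failed" alongside "completed").
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if status is None or status.status in ("completed", "failed"):
        break
    time.sleep(5)

print("Final status:", status.status if status else "job not found")
```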

## Next Steps

- Check out the [Building Applications - Fine-tuning](../building_applications/index.mdx) guide for application-level examples
- See the [Providers](../providers/post_training/index.mdx) section for detailed provider documentation
- Review the [API Reference](../advanced_apis/post_training.mdx) for complete API documentation