llama-stack/docs/source/providers/post_training/torchtune.md
---
orphan: true
---
# TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post-training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
## Features
- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- GPU support and single-device capabilities
- Support for LoRA
## Usage
To use TorchTune in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack `post_training` API.
## Setup
You can access the TorchTune trainer by writing your own YAML configuration that points to the provider:
```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```
You can then build and run your own stack with this provider.
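Before starting a job, it can help to confirm that the provider entry is really present in whatever provider list your stack reports. The following is a minimal sketch of that check in plain Python. `ProviderEntry` and `find_provider` are hypothetical names introduced here for illustration and are not part of Llama Stack; the sample entries simply mirror the YAML above.

```python
from dataclasses import dataclass, field


@dataclass
class ProviderEntry:
    """Hypothetical stand-in for one provider record (not a Llama Stack type)."""

    api: str
    provider_id: str
    provider_type: str
    config: dict = field(default_factory=dict)


def find_provider(entries, api, provider_type):
    """Return the first entry matching the API and provider type, or None."""
    for entry in entries:
        if entry.api == api and entry.provider_type == provider_type:
            return entry
    return None


# Sample data mirroring the YAML above; a real stack reports its own list.
entries = [
    ProviderEntry("inference", "example-inference", "remote::example"),
    ProviderEntry("post_training", "torchtune", "inline::torchtune"),
]

match = find_provider(entries, "post_training", "inline::torchtune")
print(match.provider_id if match else "torchtune provider not configured")
```

If your client exposes a provider listing whose records carry `api` and `provider_type` attributes (an assumption to verify against your installed client version), the same helper could be pointed at that list instead of the sample data.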
## Run Training
You can access the provider and the `supervised_fine_tune` method via the `post_training` API:
```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

# Training loop settings: data source, batching, and number of epochs
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

# LoRA settings: which modules receive adapters, and the adapter rank and alpha
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
# Note: this loop exits only when the job is missing or reports "completed".
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
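The polling loop above exits only when the job reports `completed` or cannot be found, so a job that ends any other way would keep it waiting indefinitely. Below is a minimal sketch of a bounded wait, written against a plain callable so it can be tried without a running stack. `wait_for_job` and `TERMINAL_STATUSES` are hypothetical names introduced here, and the status strings other than `completed` are assumptions; check them against the values your stack actually returns.

```python
import time

# Assumed terminal status names; only "completed" appears in the example above.
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def wait_for_job(get_status, timeout_s=3600.0, poll_interval_s=5.0,
                 sleep=time.sleep, clock=time.monotonic):
    """Poll get_status() until it returns a terminal status or the timeout passes.

    get_status is a zero-argument callable returning a status string,
    or None when the job is unknown.
    """
    deadline = clock() + timeout_s
    while True:
        status = get_status()
        if status is None:
            raise LookupError("job not found")
        if status in TERMINAL_STATUSES:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"job still '{status}' after {timeout_s} seconds")
        sleep(poll_interval_s)


# Demo with a fake status sequence and no real sleeping.
fake_statuses = iter(["scheduled", "in_progress", "completed"])
final = wait_for_job(lambda: next(fake_statuses), sleep=lambda _seconds: None)
print(final)
```

With the client from the example, `get_status` could be something like `lambda: getattr(client.post_training.job.status(job_uuid=job_uuid), "status", None)`, assuming the status response exposes a string `status` field as the comparison in the original loop suggests.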