---
orphan: true
---
# HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post-training provider for Llama Stack. It lets you run supervised fine-tuning on a variety of models using many datasets.
## Features

- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- GPU, CPU, and MPS support (macOS Metal Performance Shaders)
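As a rough sketch of what that device flexibility means in practice, a trainer typically prefers the fastest available backend. The helper below is purely illustrative and not part of the provider's API:

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then Apple's MPS backend, falling back to CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# On an Apple Silicon Mac without CUDA, MPS wins:
print(pick_device(cuda_available=False, mps_available=True))  # mps
```

The real provider performs its own hardware detection; this just shows the precedence order.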
## Usage

To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off an SFT job using the Llama Stack `post_training` API.
## Setup

You can access the HuggingFace trainer via the `ollama` distribution:

```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
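For orientation, an inline post-training provider entry in the distribution's run YAML looks roughly like the sketch below. This is illustrative only; the exact `provider_type` and `config` fields are generated by the distribution template, so consult your own `ollama-run.yaml` rather than copying this verbatim:

```yaml
post_training:
  - provider_id: huggingface       # illustrative id
    provider_type: inline::huggingface
    config:
      device: cpu                  # or "cuda" / "mps", depending on hardware
```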
## Run Training

You can access the provider and the `supervised_fine_tune` method via the `post_training` API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is also currently mandatory but should not be
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
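The dataset registered above uses the `post-training/messages` purpose. As an illustration of what such a dataset generally contains (the row below is a hypothetical example based on that purpose name, not an excerpt from the actual `llamastack/simpleqa` dataset), each row carries a `messages` list of role/content turns:

```python
# Hypothetical example row for a "post-training/messages" dataset.
# The values below are illustrative assumptions, not real dataset contents.
row = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}

# Each turn pairs a role with plain-text content.
roles = [m["role"] for m in row["messages"]]
print(roles)  # ['user', 'assistant']
```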