llama-stack/docs/source/providers/post_training/huggingface.md
Charlie Doern a7ecc92be1
docs: add post training to providers list (#2280)
# What does this PR do?

the providers list is missing post_training. Add that column and
`HuggingFace`, `TorchTune`, and `NVIDIA NEMO` as supported providers.

also point to these providers in docs/source/providers/index.md, and
describe basic functionality

There are other missing provider types here as well, but starting with
this

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
2025-05-28 09:32:00 -04:00



# HuggingFace SFTTrainer

HuggingFace SFTTrainer is an inline post-training provider for Llama Stack. It lets you run supervised fine-tuning (SFT) on a variety of models and datasets.

## Features

- Simple access through the `post_training` API
- Full integration with Llama Stack
- GPU, CPU, and MPS (macOS Metal Performance Shaders) support
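SFTTrainer comes from Hugging Face TRL, which runs on PyTorch, so GPU, CPU, and MPS support comes down to which PyTorch device (`"cuda"`, `"mps"`, or `"cpu"`) the training runs on. The sketch below shows one common preference order. `pick_device` and `detect_device` are hypothetical helpers and not part of Llama Stack. The provider may instead take its device from its configuration, so check your distribution's settings.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: prefer an NVIDIA GPU, then Apple's MPS backend, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Probe the local PyTorch install; fall back to CPU when torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps only exists in PyTorch builds that know about Apple's MPS backend
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_ok)


print(detect_device())
```

Keeping the decision in a pure function (`pick_device`) separates the preference order from the hardware probing, which makes the order easy to reason about and test.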

## Usage

To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off an SFT job using the Llama Stack `post_training` API.
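Step 1 usually means having a `post_training` entry for this provider in the distribution's run configuration (`run.yaml`). The sketch below mirrors that section as a Python dict and looks the entry up. The provider type `inline::huggingface` and the `config` keys shown are assumptions; compare them against the `run.yaml` your distribution actually generates.

```python
# Hypothetical mirror of the "providers" section of a run.yaml, written as a Python dict.
# Key names and values are assumptions to check against your own distribution.
RUN_CONFIG = {
    "providers": {
        "post_training": [
            {
                "provider_id": "huggingface",
                "provider_type": "inline::huggingface",
                "config": {"checkpoint_format": "huggingface", "device": "cpu"},
            }
        ]
    }
}


def find_provider(run_config: dict, api: str, provider_type: str):
    """Return the first provider entry for `api` whose provider_type matches, else None."""
    for entry in run_config.get("providers", {}).get(api, []):
        if entry.get("provider_type") == provider_type:
            return entry
    return None


entry = find_provider(RUN_CONFIG, "post_training", "inline::huggingface")
print(entry["provider_id"] if entry else "post_training provider not configured")
```

A real `run.yaml` loaded with PyYAML's `yaml.safe_load` has this nested dict/list shape, so the same lookup can confirm the provider is wired in before you submit a job.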

## Setup

You can access the HuggingFace trainer via the ollama distribution:

```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
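The training script talks to the server started by `llama stack run`, so it can help to wait until that server answers before submitting a job. This is a minimal sketch assuming the server exposes a health check at `GET /v1/health` on port 8321 (the port the example client uses). `server_is_healthy` and `wait_for_server` are hypothetical helpers, not part of `llama_stack_client`.

```python
import time
import urllib.request


def server_is_healthy(base_url: str = "http://localhost:8321", timeout: float = 2.0) -> bool:
    """Return True if GET {base_url}/v1/health answers with HTTP 200 (endpoint path is an assumption)."""
    try:
        with urllib.request.urlopen(f"{base_url}/v1/health", timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError, HTTPError, refused connections, and timeouts are all OSError subclasses
        return False


def wait_for_server(check=server_is_healthy, attempts: int = 30, delay: float = 1.0, sleep=time.sleep) -> bool:
    """Call `check` up to `attempts` times, sleeping `delay` seconds after each failure."""
    for _ in range(attempts):
        if check():
            return True
        sleep(delay)
    return False


if __name__ == "__main__":
    print("server ready" if wait_for_server() else "server did not come up")
```

`check` and `sleep` are injectable so the retry logic can be exercised without a running server.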

## Run Training

You can access the provider and its `supervised_fine_tune` method through the `post_training` API:

```python
import time
import uuid

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    algorithm_config_param,
    post_training_supervised_fine_tune_params,
)


def create_http_client():
    # Connect to the Llama Stack server started by `llama stack run` (default port 8321)
    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example dataset: the SimpleQA train split, pulled from the Hugging Face Hub
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

# LoRA settings; this config is currently mandatory, although it should not need to be
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job-{uuid.uuid4()}"

# Example model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)

# Poll until the job reaches a terminal state; stopping only on "completed"
# would loop forever if the job fails or is cancelled
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job finished in", end_time - start_time, "seconds")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
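Several values in the example script can be sanity-checked with plain Python before you launch a long job. The sketch covers four things. It splits the dataset URI with `urllib.parse`, which is one plausible decomposition and not necessarily how the datasetio provider parses it. It computes the effective batch size. It estimates optimizer steps, assuming the trainer keeps the final partial batch. It applies the standard LoRA arithmetic: an adapted `d_in x d_out` layer gains matrices A (`rank x d_in`) and B (`d_out x rank`), and their update is scaled by `alpha / rank`. The 1,000-example dataset size and the 4096-wide `q_proj` are hypothetical values, not properties of SimpleQA or Granite.

```python
import math
from urllib.parse import parse_qs, urlparse


def split_hf_uri(uri: str) -> tuple[str, dict[str, str]]:
    """Split huggingface://datasets/<repo>?k=v into a Hub repo id and single-valued query params."""
    parsed = urlparse(uri)
    repo_id = parsed.path.lstrip("/")
    params = {key: values[0] for key, values in parse_qs(parsed.query).items()}
    return repo_id, params


def effective_batch_size(batch_size: int, grad_accum_steps: int) -> int:
    """Examples that contribute to each optimizer update."""
    return batch_size * grad_accum_steps


def optimizer_steps(num_examples: int, batch_size: int, grad_accum_steps: int, n_epochs: int) -> int:
    """Approximate optimizer updates over a run, keeping the last partial batch of each epoch."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    return math.ceil(batches_per_epoch / grad_accum_steps) * n_epochs


def lora_params_per_layer(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


def lora_scaling(alpha: float, rank: int) -> float:
    """Factor applied to the low-rank update B @ A."""
    return alpha / rank


print(split_hf_uri("huggingface://datasets/llamastack/simpleqa?split=train"))
# ('llamastack/simpleqa', {'split': 'train'})
print(effective_batch_size(32, 1))  # 32
print(optimizer_steps(1000, 32, 1, 4))  # 128, for a hypothetical 1,000-example dataset
print(lora_params_per_layer(4096, 4096, rank=1))  # 8192, assuming a 4096 x 4096 q_proj
print(lora_scaling(alpha=1, rank=1))  # 1.0
```

With `rank=1` and only `q_proj` targeted in the attention layers, the adapter is tiny, which keeps the example cheap to run. Raising `rank` grows the parameter count linearly.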