---
orphan: true
---
# HuggingFace SFTTrainer
HuggingFace SFTTrainer is an inline post-training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets.
## Features
- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (macOS Metal Performance Shaders)
## Usage
To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider (see the configuration sketch below).
2. Kick off an SFT job using the Llama Stack `post_training` API.
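For step 1, the provider is declared in your distribution's `run.yaml`. Here is a minimal sketch of what the `post_training` section might look like; the provider type and config keys (`checkpoint_format`, `device`) are assumptions, so check the `run.yaml` generated for your distribution:

```yaml
# Hypothetical excerpt; field names may differ across Llama Stack versions.
providers:
  post_training:
  - provider_id: huggingface
    provider_type: inline::huggingface
    config:
      checkpoint_format: huggingface
      device: cpu
```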
## Setup
You can access the HuggingFace trainer via the `starter` distribution:

```bash
llama stack build --template starter --image-type venv
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
```
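Once the server is up, you can sanity-check it before kicking off a job. A quick probe, assuming the standard health endpoint path:

```bash
# Assumes the server listens on the default port 8321.
curl http://localhost:8321/v1/health
```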
## Run Training
You can access the provider and the `supervised_fine_tune` method via the `post_training` API:
```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)
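
# Training hyperparameters: four passes over the dataset with batch size 32
# and no gradient accumulation.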
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)
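
# LoRA fine-tuning trains small low-rank adapter matrices while keeping the
# base model's weights frozen; `rank` and `alpha` control adapter capacity
# and scaling, and `lora_attn_modules` selects which projections get adapters.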
algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is also currently mandatory but should not be
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
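Beyond `status` and `artifacts`, the job resource can be used to manage runs. A minimal sketch, assuming `list` and `cancel` methods exist on the same `client.post_training.job` resource used above (check your `llama_stack_client` version):

```python
# Hypothetical job-management calls; verify they exist in your client version.
# Enumerate all known post-training jobs.
for job in client.post_training.job.list():
    print(job)

# Cancel a run that is no longer needed.
client.post_training.job.cancel(job_uuid=job_uuid)
```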