feat: add huggingface post_training impl (#2132)

# What does this PR do?


Adds an inline HuggingFace SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The config allows a user to specify key fields such as the model, chat_template, device, etc.
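
For illustration, here is a minimal sketch of configuring the provider programmatically, using only the fields visible in the run.yaml diffs below (the exact constructor signature of `HuggingFacePostTrainingConfig` is an assumption):

```python
# Sketch only: field names come from the run.yaml diffs in this PR; the exact
# constructor signature of HuggingFacePostTrainingConfig is an assumption.
from llama_stack.providers.inline.post_training.huggingface import (
    HuggingFacePostTrainingConfig,
)

config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",  # save checkpoints in HF format
    distributed_backend=None,         # single-device training only
    device="cpu",                     # "cuda" or "mps" should also work
)
```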

The provider ships with one recipe, `finetune_single_device`, which works both with and without LoRA.
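
When LoRA is enabled, the recipe is driven by an algorithm config passed through the post_training API. A hedged sketch (assuming the existing `LoraFinetuningConfig` type; the values are illustrative, not the recipe's defaults):

```python
# Sketch: assumes the existing LoraFinetuningConfig from the post_training API;
# the values below are illustrative, not the recipe's defaults.
from llama_stack.apis.post_training import LoraFinetuningConfig

algorithm_config = LoraFinetuningConfig(
    type="LoRA",
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
# Passing algorithm_config=None runs full-parameter SFT instead.
```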

Any valid HF model identifier can be given, and the model will be pulled automatically.
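
Presumably this resolves and downloads the model the standard HuggingFace way, roughly like the sketch below (the provider's actual loading code may differ):

```python
# Sketch: standard HuggingFace model resolution; the provider's actual loading
# code may differ.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm-granite/granite-3.3-2b-instruct"  # any valid HF identifier
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # downloads on first use
```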

This has been tested so far with the CPU and MPS device types, but it should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, computes the steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs `n_epochs` of training.

If `checkpoint_dir` is `None`, no model is saved. If a checkpoint dir is provided, a model is saved every `save_steps` and at the end of training.
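
Roughly, the bookkeeping described above amounts to something like the following (the variable names and step heuristics here are assumptions; `SFTConfig` itself is the real trl class the recipe builds on):

```python
# Sketch only: the variable names and heuristics are assumptions, but SFTConfig
# is the real trl class the recipe builds on.
from trl import SFTConfig

n_epochs, batch_size = 1, 8
dataset_len = 1000            # rows in the processed dataset
checkpoint_dir = None         # per the description: None means "do not save"
steps_per_epoch = max(1, dataset_len // batch_size)
save_steps = eval_steps = steps_per_epoch  # e.g. once per epoch

sft_config = SFTConfig(
    output_dir=checkpoint_dir or "/tmp/sft-output",
    num_train_epochs=n_epochs,
    per_device_train_batch_size=batch_size,
    save_strategy="no" if checkpoint_dir is None else "steps",
    save_steps=save_steps,    # checkpoint every `save_steps` optimizer steps
    eval_steps=eval_steps,
    logging_steps=1,
)
```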


## Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny Granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the Llama Stack client and the proper post_training API.

It runs one step with a batch_size of 1. The test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step.
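
For context, the rough shape of the call the test makes (a sketch; the client method is the post_training API's `supervised_fine_tune`, but the config field names and values below are assumptions, not a verbatim copy of the test):

```python
# Sketch of the test's call shape; field names and values are assumptions based
# on the post_training API, not a verbatim copy of the test.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

job = client.post_training.supervised_fine_tune(
    job_uuid="test-sft-job",
    model="ibm-granite/granite-3.3-2b-instruct",
    algorithm_config=None,          # plain SFT, no LoRA
    checkpoint_dir=None,            # no checkpoints needed on the CI runner
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,   # a single step on the CPU runner
        "data_config": {
            "dataset_id": "simpleqa",
            "batch_size": 1,        # smallest possible batch
            "shuffle": False,
            "data_format": "instruct",
        },
    },
    hyperparam_search_config={},
    logger_config={},
)
```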


---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>


@@ -441,6 +441,7 @@
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"peft",
"pillow",
"psycopg2-binary",
"pymongo",
@@ -451,9 +452,11 @@
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"tqdm",
"transformers",
"tree_sitter",
"trl",
"uvicorn"
],
"open-benchmark": [


@@ -13,9 +13,10 @@ distribution_spec:
- inline::basic
- inline::braintrust
post_training:
- inline::torchtune
- inline::huggingface
datasetio:
- inline::localfs
- remote::huggingface
telemetry:
- inline::meta-reference
agents:


@@ -49,16 +49,24 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/huggingface}/huggingface_datasetio.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
post_training:
- provider_id: torchtune-post-training
provider_type: inline::torchtune
config: {
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
}
distributed_backend: null
device: cpu
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference


@@ -23,6 +23,8 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
post_training:
- inline::huggingface
tool_runtime:
- remote::brave-search
- remote::tavily-search


@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"post_training": ["inline::huggingface"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
@@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
posttraining_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="ollama",
@@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"post_training": [posttraining_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
@@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"post_training": [posttraining_provider],
"safety": [
Provider(
provider_id="llama-guard",


@@ -5,6 +5,7 @@ apis:
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
@@ -80,6 +81,13 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search


@@ -5,6 +5,7 @@ apis:
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
@@ -78,6 +79,13 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search