feat: [post training] support save hf safetensor format checkpoint (#845)

## context

Today, Llama Stack only supports inference / eval of a finetuned
checkpoint with meta-reference as the inference provider. This is
sub-optimal since meta-reference is pretty slow.

Our vision is that developers can run inference / eval on a finetuned
checkpoint produced by the post training APIs with any of the inference
providers on the stack. To achieve this, we'd like to define a unified
output checkpoint format for post training providers, so that every
inference provider can respect that format for customized model inference.

By spot-checking how
[ollama](https://github.com/ollama/ollama/blob/main/docs/import.md) and
[fireworks](https://docs.fireworks.ai/models/uploading-custom-models) run
inference on a customized model, we defined the output checkpoint format
as `/adapter/adapter_config.json` and `/adapter/adapter_model.safetensors`
(since we only support LoRA post training for now, we start with
adapter-only checkpoints).
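
To make the intended consumption concrete, here is a minimal sketch of loading such an adapter-only checkpoint with Hugging Face `transformers` + `peft` (not part of this PR; the base model id and checkpoint path are placeholders):

```python
# Hypothetical sketch: load the adapter produced by a post training job
# on top of its base model. Paths and model id below are placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "meta-llama/Llama-3.2-3B-Instruct"  # assumption: base model used for finetuning
adapter_dir = "/path/to/checkpoint/adapter"  # contains adapter_config.json + adapter_model.safetensors

base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
model = PeftModel.from_pretrained(base_model, adapter_dir)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```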

## test
We kicked off a post training job with the checkpoint format configured
as 'huggingface'. Output files:
![Screenshot 2025-02-24 at 11 54 33 PM](https://github.com/user-attachments/assets/fb45a5d7-f288-4d30-82f8-b7a8da2859be)
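
In case the screenshot doesn't render, a rough sketch of what we expect under the job's output directory (paths are illustrative and depend on your local setup):

```python
# Rough sketch (illustrative path): list what the post training job wrote out.
from pathlib import Path

# assumption: output dir name follows {model_id}-sft-{epoch} under the checkpoint dir
out = Path("~/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0").expanduser()
for p in sorted(out.rglob("*")):
    if p.is_file():
        print(p.relative_to(out))
# Expected per this PR:
#   adapter/adapter_config.json
#   adapter/adapter_model.safetensors
#   plus the base model's config / tokenizer files copied over by copy_files()
```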



We did a proof of concept with ollama to verify that it can run inference
on our finetuned checkpoint:
1. Create a Modelfile like

<img width="799" alt="Screenshot 2025-01-22 at 5 04 18 PM"
src="https://github.com/user-attachments/assets/7fca9ac3-a294-44f8-aab1-83852c600609"
/>

2. Create a customized model with `ollama create llama_3_2_finetuned`
and run inference successfully (see the sketch after the screenshot below)

![Screenshot 2025-02-24 at 11 55 17 PM](https://github.com/user-attachments/assets/1abe7c52-c6a7-491a-b07c-b7a8e3fd1ddd)
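
The same check can be scripted against ollama's HTTP API instead of the CLI. A minimal sketch, assuming the model was created as `llama_3_2_finetuned` and ollama listens on its default port:

```python
# Minimal sketch: query the locally created model through ollama's REST API.
# Assumes `ollama create llama_3_2_finetuned` already succeeded and the
# server is running on the default http://localhost:11434.
import requests

resp = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "llama_3_2_finetuned",
        "prompt": "Tell me a short story.",
        "stream": False,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])
```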


This is just a proof of concept with the ollama command line. As a next
step, we'd like to wrap the logic for loading and running inference on a
customized model inside the inference provider implementations.
Commit 123fb9eb24 (parent 63e6acd0c3) by Botao Chen, committed via GitHub on 2025-02-25 23:29:08 -08:00. 6 changed files with 6545 additions and 10 deletions.


@@ -4,15 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_CONFIG_FNAME,
+    ADAPTER_MODEL_FNAME,
+    REPO_ID_FNAME,
+    SUFFIXES_TO_NOT_COPY,
+    ModelType,
+    copy_files,
+    safe_torch_load,
+)
 from torchtune.utils._logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -75,9 +85,24 @@ class TorchtuneCheckpointer:
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
+        checkpoint_format: str = "meta",
     ) -> str:
         model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
+        if checkpoint_format == "meta":
+            self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
+        elif checkpoint_format == "huggingface":
+            # Note: for saving Hugging Face format checkpoints, we only support saving adapter weights now
+            self._save_hf_format_checkpoint(model_file_path, state_dict)
+        else:
+            raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
+        return str(model_file_path)
+
+    def _save_meta_format_checkpoint(
+        self,
+        model_file_path: Path,
+        state_dict: Dict[str, Any],
+        adapter_only: bool = False,
+    ) -> None:
         model_file_path.mkdir(parents=True, exist_ok=True)
 
         # copy the related files for inference
@@ -140,6 +165,76 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
 
-        print("model_file_path", str(model_file_path))
-
-        return str(model_file_path)
+    def _save_hf_format_checkpoint(
+        self,
+        model_file_path: Path,
+        state_dict: Dict[str, Any],
+    ) -> None:
+        # the config.json file contains model params needed for state dict conversion
+        config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
+
+        # repo_id is necessary when saving an adapter config, so it's compatible with HF.
+        # This json file is produced and saved in the download step.
+        # contents are {"repo_id": "some_model/some_model_version"}
+        repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
+        self.repo_id = None
+        if repo_id_path.exists():
+            with open(repo_id_path, "r") as json_file:
+                data = json.load(json_file)
+                self.repo_id = data.get("repo_id")
+
+        if training.ADAPTER_KEY in state_dict:
+            # TODO: saving it "as is" is a requirement because, if we only save with
+            # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
+            # convert_weights.peft_to_tune. The .pt format is not needed, but
+            # it is an easy way to distinguish the adapters. Ideally we should save only one.
+            output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            torch.save(state_dict[training.ADAPTER_KEY], output_path)
+            logger.info(
+                f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+
+            state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
+                state_dict[training.ADAPTER_KEY],
+                num_heads=config["num_attention_heads"],
+                num_kv_heads=config["num_key_value_heads"],
+                dim=config["hidden_size"],
+                head_dim=config.get("head_dim", None),
+            )
+            output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            output_path = output_path.with_suffix(".safetensors")
+            save_file(
+                state_dict[training.ADAPTER_KEY],
+                output_path,
+                metadata={"format": "pt"},
+            )
+            logger.info(
+                f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+        else:
+            raise ValueError(
+                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            )
+
+        if training.ADAPTER_CONFIG in state_dict:
+            state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
+                adapter_config=state_dict[training.ADAPTER_CONFIG],
+                base_model_name_or_path=self.repo_id,
+            )
+
+            output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
+            with open(output_path, "w") as f:
+                json.dump(state_dict[training.ADAPTER_CONFIG], f)
+            logger.info(
+                f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+
+        # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
+        # so it's easy to run inference with the model using this epoch's checkpoint
+        copy_files(
+            self._checkpoint_dir.parent,
+            model_file_path,
+            ignore_suffixes=SUFFIXES_TO_NOT_COPY,
+        )
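
For reviewers who want to sanity-check the output of `_save_hf_format_checkpoint`, a small sketch (paths are placeholders) that reads back the saved adapter files:

```python
# Minimal sketch (assumed paths): read back the saved adapter weights and config.
import json
from pathlib import Path

from safetensors.torch import load_file

adapter_dir = Path("/path/to/output/adapter")  # placeholder for the job's output
weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
print({k: tuple(v.shape) for k, v in list(weights.items())[:3]})  # a few LoRA tensors

adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
print(adapter_config.get("base_model_name_or_path"), adapter_config.get("peft_type"))
```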


@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
+    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
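
A quick sketch of how the new config field behaves (the import path is assumed from the repo layout; pydantic enforces the `Literal` values):

```python
# Minimal sketch: the new checkpoint_format field only accepts "meta" or "huggingface".
# The import path below is an assumption based on the repo layout.
from pydantic import ValidationError

from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)

print(TorchtunePostTrainingConfig().checkpoint_format)  # "meta" (default)
print(TorchtunePostTrainingConfig(checkpoint_format="huggingface").checkpoint_format)

try:
    TorchtunePostTrainingConfig(checkpoint_format="gguf")
except ValidationError as e:
    print("rejected:", e.errors()[0]["type"])
```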


@@ -117,6 +117,7 @@ class LoraFinetuningSingleDevice:
         self.checkpoint_dir = model_checkpoint_dir(model)
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
+        self._checkpoint_format = config.checkpoint_format
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -419,6 +420,7 @@ class LoraFinetuningSingleDevice:
         return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
+            checkpoint_format=self._checkpoint_format,
         )
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -460,7 +462,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0


@@ -6,6 +6,7 @@ distribution_spec:
   providers:
     inference:
     - inline::meta-reference
+    - remote::ollama
    eval:
     - inline::meta-reference
     scoring:
@@ -15,7 +16,6 @@ distribution_spec:
     - inline::torchtune
     datasetio:
     - inline::localfs
-    - remote::huggingface
     telemetry:
     - inline::meta-reference
     agents:


@@ -21,6 +21,10 @@ providers:
       max_seq_len: 4096
       checkpoint_dir: null
       create_distributed_process_group: False
+  - provider_id: ollama
+    provider_type: remote::ollama
+    config:
+      url: ${env.OLLAMA_URL:http://localhost:11434}
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -34,9 +38,6 @@ providers:
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
   datasetio:
-  - provider_id: huggingface-0
-    provider_type: remote::huggingface
-    config: {}
   - provider_id: localfs
     provider_type: inline::localfs
     config: {}
@@ -47,7 +48,9 @@ providers:
   post_training:
   - provider_id: torchtune-post-training
     provider_type: inline::torchtune
-    config: {}
+    config: {
+      checkpoint_format: huggingface
+    }
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference