feat: [post training] support save hf safetensor format checkpoint (#845)
## context
Today, llama stack only supports running inference / eval on a finetuned checkpoint with meta-reference as the inference provider. This is sub-optimal since meta-reference is pretty slow. Our vision is that developers can run inference / eval on a finetuned checkpoint produced by the post training APIs with any inference provider on the stack. To achieve this, we'd like to define a unified output checkpoint format for post training providers, so that all inference providers can respect that format for customized model inference. By spot-checking how [ollama](https://github.com/ollama/ollama/blob/main/docs/import.md) and [fireworks](https://docs.fireworks.ai/models/uploading-custom-models) run inference on a customized model, we defined the output checkpoint format as `/adapter/adapter_config.json` and `/adapter/adapter_model.safetensors` (since we only support LoRA post training right now, we start with an adapter-only checkpoint).

## test
We kicked off a post training job with the checkpoint format configured as 'huggingface'. Output files:



We also did a proof of concept with ollama to check that it can run inference on our finetuned checkpoint:

1. Create a Modelfile like:

<img width="799" alt="Screenshot 2025-01-22 at 5 04 18 PM" src="https://github.com/user-attachments/assets/7fca9ac3-a294-44f8-aab1-83852c600609" />

2. Create a customized model with `ollama create llama_3_2_finetuned` and run inference successfully:



This is just a proof of concept with the ollama command line. As a next step, we'd like to wrap the loading / inference logic for customized models inside the inference provider implementation.
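Because the adapter is written in the standard Hugging Face PEFT layout, any HF-compatible runtime should be able to pick it up directly. As a rough illustration (not part of this change; the base model id and adapter path below are hypothetical placeholders), loading the checkpoint with the `peft` library would look something like:

```python
# Sketch only: loads an adapter saved in the new HF format.
# The base model repo id and local adapter path are hypothetical.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
# Directory containing adapter_config.json and adapter_model.safetensors
model = PeftModel.from_pretrained(base, "/path/to/checkpoint/adapter")
```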
parent 63e6acd0c3 · commit 123fb9eb24

6 changed files with 6545 additions and 10 deletions
docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb (new file, +6434)
File diff suppressed because one or more lines are too long
```diff
@@ -4,15 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_CONFIG_FNAME,
+    ADAPTER_MODEL_FNAME,
+    REPO_ID_FNAME,
+    SUFFIXES_TO_NOT_COPY,
+    ModelType,
+    copy_files,
+    safe_torch_load,
+)
 from torchtune.utils._logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -75,9 +85,24 @@ class TorchtuneCheckpointer:
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
+        checkpoint_format: str = "meta",
     ) -> str:
         model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
+        if checkpoint_format == "meta":
+            self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
+        elif checkpoint_format == "huggingface":
+            # Note: for saving hugging face format checkpoints, we only support saving adapter weights now
+            self._save_hf_format_checkpoint(model_file_path, state_dict)
+        else:
+            raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
+        return str(model_file_path)
+
+    def _save_meta_format_checkpoint(
+        self,
+        model_file_path: Path,
+        state_dict: Dict[str, Any],
+        adapter_only: bool = False,
+    ) -> None:
         model_file_path.mkdir(parents=True, exist_ok=True)
 
         # copy the related files for inference
@@ -140,6 +165,76 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
 
         print("model_file_path", str(model_file_path))
-        return str(model_file_path)
+
+    def _save_hf_format_checkpoint(
+        self,
+        model_file_path: Path,
+        state_dict: Dict[str, Any],
+    ) -> None:
+        # the config.json file contains model params needed for state dict conversion
+        config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
+
+        # repo_id is necessary when saving an adapter config, so it's compatible with HF.
+        # This json file is produced and saved in the download step.
+        # contents are {"repo_id": "some_model/some_model_version"}
+        repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
+        self.repo_id = None
+        if repo_id_path.exists():
+            with open(repo_id_path, "r") as json_file:
+                data = json.load(json_file)
+                self.repo_id = data.get("repo_id")
+
+        if training.ADAPTER_KEY in state_dict:
+            # TODO: saving it "as is" is a requirement because, if we only save with
+            # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
+            # convert_weights.peft_to_tune. The .pt format is not needed, but
+            # it is an easy way to distinguish the adapters. Ideally we should save only one.
+            output_path = Path.joinpath(model_file_path, ADAPTER_MODEL_FNAME).with_suffix(".pt")
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            torch.save(state_dict[training.ADAPTER_KEY], output_path)
+            logger.info(
+                f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+
+            state_dict[training.ADAPTER_KEY] = convert_weights.tune_to_peft_adapter_weights(
+                state_dict[training.ADAPTER_KEY],
+                num_heads=config["num_attention_heads"],
+                num_kv_heads=config["num_key_value_heads"],
+                dim=config["hidden_size"],
+                head_dim=config.get("head_dim", None),
+            )
+            output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_MODEL_FNAME)
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            output_path = output_path.with_suffix(".safetensors")
+            save_file(
+                state_dict[training.ADAPTER_KEY],
+                output_path,
+                metadata={"format": "pt"},
+            )
+            logger.info(
+                f"Adapter checkpoint of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+        else:
+            raise ValueError(
+                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            )
+
+        if training.ADAPTER_CONFIG in state_dict:
+            state_dict[training.ADAPTER_CONFIG] = convert_weights.tune_to_peft_adapter_config(
+                adapter_config=state_dict[training.ADAPTER_CONFIG],
+                base_model_name_or_path=self.repo_id,
+            )
+
+            output_path = Path.joinpath(model_file_path, "adapter", ADAPTER_CONFIG_FNAME).with_suffix(".json")
+            with open(output_path, "w") as f:
+                json.dump(state_dict[training.ADAPTER_CONFIG], f)
+            logger.info(
+                f"Adapter config of size {os.path.getsize(output_path) / 1024**3:.2f} GiB saved to {output_path}"
+            )
+
+        # Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
+        # so it's easy to run inference with the model using this epoch's checkpoint
+        copy_files(
+            self._checkpoint_dir.parent,
+            model_file_path,
+            ignore_suffixes=SUFFIXES_TO_NOT_COPY,
+        )
```
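The heart of the new `_save_hf_format_checkpoint` path is the safetensors write. Below is a minimal, self-contained sketch of the same save/load pattern; the tensor names are made up for illustration, while the real keys come from `convert_weights.tune_to_peft_adapter_weights`:

```python
# Sketch: round-trips adapter-like tensors through safetensors, as the
# checkpointer does above. Tensor names here are dummies, not real keys.
import torch
from safetensors.torch import load_file, save_file

weights = {
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 16),
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(16, 8),
}
save_file(weights, "adapter_model.safetensors", metadata={"format": "pt"})

# safetensors stores raw tensor data plus a small header; unlike torch.save,
# loading never executes pickled code.
loaded = load_file("adapter_model.safetensors")
assert set(loaded) == set(weights)
```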
```diff
@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
 
 class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
+    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
```
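Since `checkpoint_format` is declared as a `Literal`, pydantic rejects unknown formats when the provider config is parsed, rather than deep inside a training run. A small sketch of that behavior (it mirrors the field above and is not part of the diff):

```python
# Sketch: demonstrates Literal validation on the config field added above.
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"


TorchtunePostTrainingConfig(checkpoint_format="huggingface")  # accepted
try:
    TorchtunePostTrainingConfig(checkpoint_format="gguf")  # not a valid Literal value
except ValidationError as err:
    print(err)
```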
```diff
@@ -117,6 +117,7 @@ class LoraFinetuningSingleDevice:
         self.checkpoint_dir = model_checkpoint_dir(model)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
+        self._checkpoint_format = config.checkpoint_format
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -419,6 +420,7 @@ class LoraFinetuningSingleDevice:
         return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
+            checkpoint_format=self._checkpoint_format,
         )
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -460,7 +462,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0
```
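Putting the recipe and checkpointer changes together, and assuming the training algorithm string is `sft` (so the checkpointer's output directory and the DiskLogger path coincide), the per-epoch output directory should look roughly like this. The tree is inferred from the f-strings and file names above, not copied from the PR:

```
<output_dir>/<model_id>-sft-<epoch>/
├── adapter/
│   ├── adapter_config.json        # HF PEFT adapter config
│   └── adapter_model.safetensors  # adapter weights (HF format)
├── adapter_model.pt               # torchtune-native adapter copy
├── log/                           # DiskLogger metrics for this epoch
└── ...                            # tokenizer/config files copied by copy_files
```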
```diff
@@ -6,6 +6,7 @@ distribution_spec:
   providers:
     inference:
     - inline::meta-reference
+    - remote::ollama
     eval:
     - inline::meta-reference
     scoring:
@@ -15,7 +16,6 @@ distribution_spec:
    - inline::torchtune
     datasetio:
     - inline::localfs
-    - remote::huggingface
     telemetry:
     - inline::meta-reference
     agents:
```
```diff
@@ -21,6 +21,10 @@ providers:
       max_seq_len: 4096
       checkpoint_dir: null
       create_distributed_process_group: False
+  - provider_id: ollama
+    provider_type: remote::ollama
+    config:
+      url: ${env.OLLAMA_URL:http://localhost:11434}
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -34,9 +38,6 @@ providers:
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
   datasetio:
-  - provider_id: huggingface-0
-    provider_type: remote::huggingface
-    config: {}
   - provider_id: localfs
     provider_type: inline::localfs
     config: {}
@@ -47,7 +48,9 @@ providers:
   post_training:
   - provider_id: torchtune-post-training
     provider_type: inline::torchtune
-    config: {}
+    config: {
+      checkpoint_format: huggingface
+    }
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
```
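For context, `${env.OLLAMA_URL:http://localhost:11434}` uses llama stack's environment-variable substitution, where the text after the colon is the default. A rough Python approximation of that lookup (a sketch of the pattern, not the stack's actual resolver):

```python
# Sketch: expands ${env.VAR:default} strings the way the run.yaml above
# appears to intend; llama stack's real resolver may differ in details.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):(?P<default>[^}]*)\}")


def expand_env(value: str) -> str:
    return _ENV_PATTERN.sub(
        lambda m: os.environ.get(m.group("name"), m.group("default")), value
    )


print(expand_env("${env.OLLAMA_URL:http://localhost:11434}"))
# -> OLLAMA_URL if set in the environment, else http://localhost:11434
```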