forked from phoenix-oss/llama-stack-mirror
[2/n][torchtune integration] implement job management and return training artifacts (#593)
### Context In this PR, we - Implement the post training job management and get training artifacts apis - get_training_jobs - get_training_job_status - get_training_job_artifacts - get_training_job_logstream is deleted since the trace can be directly accessed by UI with Jaeger https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces - Refactor the post training and training types definition to make them more intuitive. - Rewrite the checkpointer to make it compatible with llama-stack file system and can be recognized during inference ### Test Unit test `pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings` <img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM" src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a"> e2e test with client side call <img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM" src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
This commit is contained in:
parent
5764a95912
commit
c294a01c4b
8 changed files with 331 additions and 67 deletions
|
@ -0,0 +1,157 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torchtune import training
|
||||
from torchtune.models import convert_weights
|
||||
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
|
||||
from torchtune.utils._logging import get_logger
|
||||
|
||||
logger = get_logger("DEBUG")
|
||||
|
||||
|
||||
class TorchtuneCheckpointer:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
training_algorithm: str,
|
||||
checkpoint_dir: str,
|
||||
checkpoint_files: List[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
# Fail fast if ``checkpoint_files`` is invalid
|
||||
# TODO: support loading more than one file
|
||||
if len(checkpoint_files) != 1:
|
||||
raise ValueError(
|
||||
"Currently we only support reading from a single torchtune checkpoint file. "
|
||||
f"Got {len(checkpoint_files)} files instead."
|
||||
)
|
||||
self._checkpoint_file = checkpoint_files[0]
|
||||
self._model_id = model_id
|
||||
self._training_algorithm = training_algorithm
|
||||
self._checkpoint_dir = Path(checkpoint_dir)
|
||||
self._model_type = ModelType[model_type]
|
||||
self._output_dir = output_dir
|
||||
# get ckpt paths
|
||||
self._checkpoint_path = Path.joinpath(
|
||||
self._checkpoint_dir, self._checkpoint_file
|
||||
)
|
||||
|
||||
def load_checkpoint(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str:Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_meta_to_tune,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
|
||||
model_state_dict
|
||||
)
|
||||
else:
|
||||
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
|
||||
model_state_dict
|
||||
)
|
||||
|
||||
# llama3_2 has tied weights, so we need to remove the output.weight key
|
||||
if self._model_type == ModelType.LLAMA3_2:
|
||||
logger.info(
|
||||
"Identified model_type = Llama3_2. Ignoring output.weight in"
|
||||
" checkpoint in favor of the tok_embedding.weight"
|
||||
" tied weights."
|
||||
)
|
||||
state_dict[training.MODEL_KEY].pop("output.weight")
|
||||
|
||||
return state_dict
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
) -> str:
|
||||
model_file_path = (
|
||||
Path(self._output_dir)
|
||||
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||
)
|
||||
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "params.json"),
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
llama3_vision_tune_to_meta,
|
||||
)
|
||||
|
||||
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
|
||||
model_state_dict
|
||||
)
|
||||
else:
|
||||
# llama3_2 has tied weights, so we need to add the output.weight key
|
||||
if (
|
||||
self._model_type == ModelType.LLAMA3_2
|
||||
and "output.weight" not in model_state_dict
|
||||
):
|
||||
model_state_dict["output.weight"] = model_state_dict[
|
||||
"tok_embeddings.weight"
|
||||
]
|
||||
|
||||
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
|
||||
model_state_dict
|
||||
)
|
||||
|
||||
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
|
||||
|
||||
torch.save(state_dict[training.MODEL_KEY], model_file_name)
|
||||
logger.info(
|
||||
"Model checkpoint of size "
|
||||
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {model_file_name}"
|
||||
)
|
||||
|
||||
if training.ADAPTER_KEY in state_dict:
|
||||
adapter_file_path = model_file_path / "adapter"
|
||||
adapter_file_path.mkdir(parents=True, exist_ok=True)
|
||||
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
|
||||
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
|
||||
logger.info(
|
||||
"Adapter checkpoint of size "
|
||||
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
|
||||
f"saved to {adapter_file_name}"
|
||||
)
|
||||
|
||||
elif adapter_only:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
print("model_file_path", str(model_file_path))
|
||||
|
||||
return str(model_file_path)
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.common.type_system import * # noqa
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
|
||||
class ColumnName(Enum):
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
class DatasetSchema(BaseModel):
|
||||
alpaca: List[Dict[str, ParamType]]
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||
alpaca=[
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
ColumnName.text.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
raise ValueError(f"Model {model_id} is not supported.")
|
||||
return model
|
||||
|
||||
|
||||
async def get_model_definition(
|
||||
model_id: str,
|
||||
) -> BuildLoraModelCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "model_definition"):
|
||||
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||
return model_config.model_definition
|
||||
|
||||
|
||||
async def get_tokenizer_type(
|
||||
model_id: str,
|
||||
) -> BuildTokenizerCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "tokenizer_type"):
|
||||
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||
return model_config.tokenizer_type
|
||||
|
||||
|
||||
async def get_checkpointer_model_type(
|
||||
model_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||
"""
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "checkpoint_type"):
|
||||
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||
return model_config.checkpoint_type
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue