[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context 
In this PR, we 
- Implement the post training job management and get training artifacts
apis
  - get_training_jobs
  - get_training_job_status
  - get_training_job_artifacts
- get_training_job_logstream is deleted since the trace can be directly
accessed by UI with Jaeger
https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post training and training types definition to make them
more intuitive.
- Rewrite the checkpointer to make it compatible with llama-stack file
system and can be recognized during inference


### Test
Unit test
`pytest llama_stack/providers/tests/post_training/test_post_training.py
-m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short
--disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


e2e test with client side call

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
This commit is contained in:
Botao Chen 2024-12-13 15:00:04 -08:00 committed by GitHub
parent 5764a95912
commit c294a01c4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 331 additions and 67 deletions

View file

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str:Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
return str(model_file_path)

View file

@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Callable, Dict, List
import torch
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.common.type_system import * # noqa
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
checkpoint_type: str
class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama-3-8B-Instruct": ModelConfig(
model_definition=lora_llama3_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),
}
EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
raise ValueError(f"Model {model_id} is not supported.")
return model
async def get_model_definition(
model_id: str,
) -> BuildLoraModelCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "model_definition"):
raise ValueError(f"Model {model_id} does not have model definition.")
return model_config.model_definition
async def get_tokenizer_type(
model_id: str,
) -> BuildTokenizerCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "tokenizer_type"):
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
return model_config.tokenizer_type
async def get_checkpointer_model_type(
model_id: str,
) -> str:
"""
checkpointer model type is used in checkpointer for some special treatment on some specific model types
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
"""
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "checkpoint_type"):
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
return model_config.checkpoint_type
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)