forked from phoenix-oss/llama-stack-mirror
### Context
This is the first in a series of PRs that integrate torchtune with llama-stack as the meta-reference post-training implementation. For the MVP, we focus on single-device LoRA SFT. Although this PR is still a work in progress, we want early feedback on the high-level design of this skeleton while we continue working on several details.

### Scope
To limit the scope of this PR, we focus on the skeleton of the implementation.

**What is included?**
- Refine the post-training SFT APIs.
- A skeleton of the `supervised_fine_tune` implementation. We verified that the `supervised_fine_tune` API can be called successfully from the llama-stack client SDK (client-side PR: https://github.com/meta-llama/llama-stack-client-python/pull/51).
- A very basic single-device LoRA training recipe based on torchtune core components.
- A parity check with the torchtune library, and a post-training API unit test.

**What is not included?**
- Implementation of the other job-management and training-artifact retrieval APIs (separate PR).
- Refactoring the meta-reference inference logic to support evaluation of the fine-tuned model (separate PR).
- Several necessary features in the training recipe, such as logging and validation (separate PR).
- Interop with telemetry for tracing and metrics logging; for now, logs are temporarily written to local disk (separate PR).

### Testing
**E2E test**
Although we haven't added detailed tests or a numerical parity check with torchtune yet, we ran a simple E2E test from client to server:
1. Set up the server with `llama stack build --template experimental-post-training --image-type conda` and `llama stack run experimental-post-training`.
2. On the client, run `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training supervised_fine_tune`.
3. Training finishes successfully. On the server side, the fine-tuned checkpoints are available under the output dir.
On the client side, we get the job UUID.

Server:
<img width="1110" alt="Screenshot 2024-12-02 at 5 52 32 PM" src="https://github.com/user-attachments/assets/b548eb90-7a9b-4edc-a858-ee237cc4361d">

Client:
<img width="807" alt="Screenshot 2024-12-02 at 5 52 37 PM" src="https://github.com/user-attachments/assets/1138ffa8-4698-40fa-b190-3d7b99646838">

**Parity check**
The torchtune dataloader output and the llama-stack post-training dataloader output are the same.
<img width="1116" alt="Screenshot 2024-12-04 at 8 18 46 PM" src="https://github.com/user-attachments/assets/5e295cdc-4c24-4ea6-82c0-ca96ef1bd6ee">

torchtune LoRA SFT and llama-stack post-training LoRA SFT on the alpaca dataset with the Llama 3.2 3B Instruct model match numerically.
<img width="860" alt="Screenshot 2024-12-04 at 8 17 01 PM" src="https://github.com/user-attachments/assets/c05cf0a8-c674-4d2e-9f0a-c5d01b2dca99"> <img width="1049" alt="Screenshot 2024-12-04 at 8 17 06 PM" src="https://github.com/user-attachments/assets/b911d4e2-e7b1-41a9-b62c-d75529b6d443">

**Unit test**
![Uploading Screenshot 2024-12-09 at 1.35.10 PM.png…]() _(screenshot upload did not complete; image missing)_
139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
import torch
|
|
from llama_stack.apis.datasets import Datasets
|
|
from llama_stack.apis.common.type_system import * # noqa
|
|
from llama_models.datatypes import Model
|
|
from llama_models.sku_list import resolve_model
|
|
from llama_stack.apis.common.type_system import ParamType
|
|
|
|
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
|
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
|
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
|
|
|
|
|
class ColumnName(Enum):
    """Column names that can appear in a fine-tuning dataset schema.

    Members are declared in the order instruction, input, output, text; that
    is also the enum's iteration order.
    """

    instruction = "instruction"
    input = "input"
    output = "output"
    text = "text"
|
|
|
|
|
|
class ModelConfig(BaseModel):
    """Per-model settings used to set up a torchtune LoRA fine-tuning run."""

    # Builder for the LoRA-wrapped model (e.g. ``lora_llama3_2_3b``).
    model_definition: Any
    # Builder for the tokenizer (e.g. ``llama3_tokenizer``).
    tokenizer_type: Any
    # Model-type tag handed to the checkpointer (e.g. ``"LLAMA3_2"``).
    checkpoint_type: str
|
|
|
|
|
|
class DatasetSchema(BaseModel):
    """Accepted dataset column layouts, grouped by dataset family."""

    # Each entry maps a column name to its expected column type; a dataset of
    # the alpaca family is accepted if its schema equals one of these entries.
    alpaca: List[Dict[str, ParamType]]
|
|
|
|
|
|
# Supported models, keyed by the core model id string that
# ``resolve_model(model_id).core_model_id.value`` produces.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    # Llama 3.2 3B Instruct: LoRA builder from torchtune.models.llama3_2,
    # checkpointed with the "LLAMA3_2" model type.
    "Llama3.2-3B-Instruct": ModelConfig(
        model_definition=lora_llama3_2_3b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3_2",
    ),
    # Llama 3 8B Instruct: LoRA builder from torchtune.models.llama3,
    # checkpointed with the "LLAMA3" model type.
    "Llama-3-8B-Instruct": ModelConfig(
        model_definition=lora_llama3_8b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3",
    ),
}
|
|
|
|
|
|
# Schemas a training dataset may declare, per dataset family. For alpaca, the
# accepted layouts go from most to least complete, and every column is
# expected to hold strings.
EXPECTED_DATASET_SCHEMA = DatasetSchema(
    alpaca=[
        {column.value: StringType() for column in layout}
        for layout in (
            (
                ColumnName.instruction,
                ColumnName.input,
                ColumnName.output,
                ColumnName.text,
            ),
            (ColumnName.instruction, ColumnName.input, ColumnName.output),
            (ColumnName.instruction, ColumnName.output),
        )
    ]
)
|
|
|
|
# Type of a model builder stored in ``ModelConfig.model_definition``: any
# argument list, returning a ``torch.nn.Module``.
BuildLoraModelCallable = Callable[..., torch.nn.Module]
# Type of a tokenizer builder stored in ``ModelConfig.tokenizer_type``: any
# argument list, returning a ``Llama3Tokenizer``.
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
|
|
|
|
|
def _validate_model_id(model_id: str) -> Model:
    """Resolve ``model_id`` and make sure it has an entry in ``MODEL_CONFIGS``.

    Args:
        model_id: Identifier passed to ``resolve_model``.

    Returns:
        The resolved ``Model``.

    Raises:
        ValueError: If ``resolve_model`` returns ``None`` or the model's core
            id has no entry in ``MODEL_CONFIGS``.
    """
    resolved = resolve_model(model_id)
    is_supported = (
        resolved is not None and resolved.core_model_id.value in MODEL_CONFIGS
    )
    if not is_supported:
        raise ValueError(f"Model {model_id} is not supported.")
    return resolved
|
|
|
|
|
|
async def get_model_definition(
    model_id: str,
) -> BuildLoraModelCallable:
    """Return the LoRA model builder registered for ``model_id``.

    Args:
        model_id: Identifier passed to ``resolve_model``.

    Returns:
        The ``model_definition`` of the matching ``MODEL_CONFIGS`` entry
        (for example ``lora_llama3_2_3b``).

    Raises:
        ValueError: If the model is not supported, or its config has no
            ``model_definition`` attribute.
    """
    config = MODEL_CONFIGS[_validate_model_id(model_id).core_model_id.value]
    if hasattr(config, "model_definition"):
        return config.model_definition
    raise ValueError(f"Model {model_id} does not have model definition.")
|
|
|
|
|
|
async def get_tokenizer_type(
    model_id: str,
) -> BuildTokenizerCallable:
    """Return the tokenizer builder registered for ``model_id``.

    Args:
        model_id: Identifier passed to ``resolve_model``.

    Returns:
        The ``tokenizer_type`` of the matching ``MODEL_CONFIGS`` entry
        (for example ``llama3_tokenizer``).

    Raises:
        ValueError: If the model is not supported, or its config has no
            ``tokenizer_type`` attribute.
    """
    core_id = _validate_model_id(model_id).core_model_id.value
    config = MODEL_CONFIGS[core_id]
    if hasattr(config, "tokenizer_type"):
        return config.tokenizer_type
    raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
|
|
|
|
|
async def get_checkpointer_model_type(
    model_id: str,
) -> str:
    """Return the checkpointer model-type tag registered for ``model_id``.

    The checkpointer uses this tag to apply model-specific handling, for
    example Llama 3.2 tied weights
    (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041).

    Args:
        model_id: Identifier passed to ``resolve_model``.

    Returns:
        The ``checkpoint_type`` string of the matching ``MODEL_CONFIGS``
        entry (for example ``"LLAMA3_2"``).

    Raises:
        ValueError: If the model is not supported, or its config has no
            ``checkpoint_type`` attribute.
    """
    core_id = _validate_model_id(model_id).core_model_id.value
    config = MODEL_CONFIGS[core_id]
    if hasattr(config, "checkpoint_type"):
        return config.checkpoint_type
    raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
|
|
|
|
|
async def validate_input_dataset_schema(
    datasets_api: Datasets,
    dataset_id: str,
    dataset_type: str,
) -> None:
    """Check that a registered dataset declares a schema usable for training.

    Args:
        datasets_api: Datasets API used to fetch the dataset definition.
        dataset_id: Identifier of the registered dataset to check.
        dataset_type: Name of a ``DatasetSchema`` field on
            ``EXPECTED_DATASET_SCHEMA`` (e.g. ``"alpaca"``). The field holds
            the list of accepted column schemas.

    Raises:
        ValueError: If the dataset is not found or has no schema, if
            ``dataset_type`` does not name a list of accepted schemas, or if
            the dataset's schema is not one of them.
    """
    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
    # NOTE(review): get_dataset is assumed to return None for an unknown id.
    # Without this guard, the attribute access below would raise an opaque
    # AttributeError instead of a ValueError.
    if dataset_def is None:
        raise ValueError(f"Dataset {dataset_id} not found.")

    # None and an empty mapping are both falsy, so one test covers both.
    if not dataset_def.dataset_schema:
        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

    # hasattr() on a pydantic model instance is also true for methods and
    # dunders (e.g. "copy", "__class__"). A later membership test on such a
    # value would raise TypeError. Only list-valued attributes are schema
    # families, so reject everything else up front.
    expected_schemas = getattr(EXPECTED_DATASET_SCHEMA, dataset_type, None)
    if not isinstance(expected_schemas, list):
        raise ValueError(f"Dataset type {dataset_type} is not supported.")

    if dataset_def.dataset_schema not in expected_schemas:
        raise ValueError(
            f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
        )
|