From d55a8343ea6a238e84a24b228c8e0ff36a0c4aec Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 13 Dec 2024 12:55:21 -0800
Subject: [PATCH] merge

---
 .../post_training/torchtune/common/utils.py   | 121 ++++++++++++---
 .../post_training/torchtune/post_training.py  |   3 +-
 .../recipes/lora_finetuning_single_device.py  |  11 +-
 .../inline/post_training/torchtune/utils.py   | 139 ------------------
 4 files changed, 111 insertions(+), 163 deletions(-)
 delete mode 100644 llama_stack/providers/inline/post_training/torchtune/utils.py

diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 93c7ef189..462cbc21e 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -10,49 +10,130 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Callable, Dict
+from enum import Enum
+from typing import Any, Callable, Dict, List
 
 import torch
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.common.type_system import *  # noqa
+from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.type_system import ParamType
 
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 
-LORA_MODEL_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": lora_llama3_2_3b,
-    "Llama-3-8B-Instruct": lora_llama3_8b,
+
+class ColumnName(Enum):
+    instruction = "instruction"
+    input = "input"
+    output = "output"
+    text = "text"
+
+
+class ModelConfig(BaseModel):
+    model_definition: Any
+    tokenizer_type: Any
+    checkpoint_type: str
+
+
+class DatasetSchema(BaseModel):
+    alpaca: List[Dict[str, ParamType]]
+
+
+MODEL_CONFIGS: Dict[str, ModelConfig] = {
+    "Llama3.2-3B-Instruct": ModelConfig(
+        model_definition=lora_llama3_2_3b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3_2",
+    ),
+    "Llama-3-8B-Instruct": ModelConfig(
+        model_definition=lora_llama3_8b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3",
+    ),
 }
-TOKENIZER_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": llama3_tokenizer,
-    "Llama-3-8B-Instruct": llama3_tokenizer,
-}
 
-CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
-    "Llama3.2-3B-Instruct": "LLAMA3_2",
-}
 
+EXPECTED_DATASET_SCHEMA = DatasetSchema(
+    alpaca=[
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+            ColumnName.text.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+    ]
+)
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
 
 
-def get_model_type(
+def _validate_model_id(model_id: str) -> Model:
+    model = resolve_model(model_id)
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_id} is not supported.")
+    return model
+
+
+async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-    model = resolve_model(model_id)
-    return LORA_MODEL_TYPES[model.core_model_id.value]
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "model_definition"):
+        raise ValueError(f"Model {model_id} does not have model definition.")
+    return model_config.model_definition
 
 
-def get_tokenizer_type(
+async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-    model = resolve_model(model_id)
-    return TOKENIZER_TYPES[model.core_model_id.value]
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "tokenizer_type"):
+        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
+    return model_config.tokenizer_type
 
 
-def get_checkpointer_model_type(
+async def get_checkpointer_model_type(
     model_id: str,
 ) -> str:
-    model = resolve_model(model_id)
-    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
+    """
+    checkpointer model type is used in checkpointer for some special treatment on some specific model types
+    For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
+    """
+    model = _validate_model_id(model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    if not hasattr(model_config, "checkpoint_type"):
+        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
+    return model_config.checkpoint_type
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+        raise ValueError(f"Dataset type {dataset_type} is not supported.")
+
+    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+        raise ValueError(
+            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
+        )
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 808d87045..4306752e1 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -37,7 +37,7 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
         if job_uuid in self.jobs_list:
             raise ValueError(f"Job {job_uuid} already exists")
@@ -63,6 +63,7 @@ class TorchtunePostTrainingImpl:
                     checkpoint_dir,
                     algorithm_config,
                     self.datasetio_api,
+                    self.datasets_api,
                 )
 
                 job_status_response.status = JobStatus.in_progress
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 958150c91..0714046bf 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -74,12 +74,16 @@ class LoraFinetuningSingleDevice:
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
         self.job_uuid = job_uuid
         self.training_config = training_config
+        if not isinstance(algorithm_config, LoraFinetuningConfig):
+            raise ValueError(
+                "You need to specify LoraFinetuningConfig for LoRA finetuning"
+            )
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -134,6 +138,7 @@ class LoraFinetuningSingleDevice:
         )
 
         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
 
     async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
@@ -152,7 +157,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            model_type=utils.get_checkpointer_model_type(self.model_id),
+            model_type=await utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
@@ -297,7 +302,7 @@ class LoraFinetuningSingleDevice:
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
-        tokenizer_type = utils.get_tokenizer_type(self.model_id)
+        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)
 
     async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
deleted file mode 100644
index 462cbc21e..000000000
--- a/llama_stack/providers/inline/post_training/torchtune/utils.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import Any, Callable, Dict, List
-
-import torch
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.common.type_system import *  # noqa
-from llama_models.datatypes import Model
-from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType
-
-from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
-from torchtune.models.llama3_2 import lora_llama3_2_3b
-
-
-class ColumnName(Enum):
-    instruction = "instruction"
-    input = "input"
-    output = "output"
-    text = "text"
-
-
-class ModelConfig(BaseModel):
-    model_definition: Any
-    tokenizer_type: Any
-    checkpoint_type: str
-
-
-class DatasetSchema(BaseModel):
-    alpaca: List[Dict[str, ParamType]]
-
-
-MODEL_CONFIGS: Dict[str, ModelConfig] = {
-    "Llama3.2-3B-Instruct": ModelConfig(
-        model_definition=lora_llama3_2_3b,
-        tokenizer_type=llama3_tokenizer,
-        checkpoint_type="LLAMA3_2",
-    ),
-    "Llama-3-8B-Instruct": ModelConfig(
-        model_definition=lora_llama3_8b,
-        tokenizer_type=llama3_tokenizer,
-        checkpoint_type="LLAMA3",
-    ),
-}
-
-
-EXPECTED_DATASET_SCHEMA = DatasetSchema(
-    alpaca=[
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-            ColumnName.text.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-    ]
-)
-
-BuildLoraModelCallable = Callable[..., torch.nn.Module]
-BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
-
-
-def _validate_model_id(model_id: str) -> Model:
-    model = resolve_model(model_id)
-    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
-        raise ValueError(f"Model {model_id} is not supported.")
-    return model
-
-
-async def get_model_definition(
-    model_id: str,
-) -> BuildLoraModelCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
-    if not hasattr(model_config, "model_definition"):
-        raise ValueError(f"Model {model_id} does not have model definition.")
-    return model_config.model_definition
-
-
-async def get_tokenizer_type(
-    model_id: str,
-) -> BuildTokenizerCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
-    if not hasattr(model_config, "tokenizer_type"):
-        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
-    return model_config.tokenizer_type
-
-
-async def get_checkpointer_model_type(
-    model_id: str,
-) -> str:
-    """
-    checkpointer model type is used in checkpointer for some special treatment on some specific model types
-    For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
-    """
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
-    if not hasattr(model_config, "checkpoint_type"):
-        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
-    return model_config.checkpoint_type
-
-
-async def validate_input_dataset_schema(
-    datasets_api: Datasets,
-    dataset_id: str,
-    dataset_type: str,
-) -> None:
-    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
-    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-
-    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(f"Dataset type {dataset_type} is not supported.")
-
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(
-            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
-        )
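
Reviewer note (not part of the patch): the helpers moved into common/utils.py are now async and keyed off MODEL_CONFIGS / EXPECTED_DATASET_SCHEMA, so callers have to await them. A minimal usage sketch, assuming the module path introduced above and an already-registered "alpaca"-style dataset; the setup_example function and its datasets_api argument are illustrative only:

    from llama_stack.providers.inline.post_training.torchtune.common import utils

    async def setup_example(model_id: str, dataset_id: str, datasets_api) -> None:
        # Resolve the torchtune building blocks for a supported model id,
        # e.g. "Llama3.2-3B-Instruct"; each helper raises ValueError for
        # unsupported models.
        model_definition = await utils.get_model_definition(model_id)
        tokenizer_type = await utils.get_tokenizer_type(model_id)
        checkpoint_type = await utils.get_checkpointer_model_type(model_id)

        # Fail fast if the registered dataset does not match one of the
        # accepted "alpaca" column layouts in EXPECTED_DATASET_SCHEMA.
        await utils.validate_input_dataset_schema(
            datasets_api=datasets_api,
            dataset_id=dataset_id,
            dataset_type="alpaca",
        )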