From 29d0896ec8186a1d524eee1e9541e0bc6e018404 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 13 Dec 2024 10:38:53 -0800
Subject: [PATCH] address commit

---
 .../inline/post_training/torchtune/utils.py  | 29 +++++--------------
 1 file changed, 8 insertions(+), 21 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
index c923eeefd..462cbc21e 100644
--- a/llama_stack/providers/inline/post_training/torchtune/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/utils.py
@@ -38,27 +38,22 @@ class ModelConfig(BaseModel):
     checkpoint_type: str
 
 
-class ModelConfigs(BaseModel):
-    Llama3_2_3B_Instruct: ModelConfig
-    Llama_3_8B_Instruct: ModelConfig
-
-
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
 
 
-MODEL_CONFIGS = ModelConfigs(
-    Llama3_2_3B_Instruct=ModelConfig(
+MODEL_CONFIGS: Dict[str, ModelConfig] = {
+    "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3_2",
     ),
-    Llama_3_8B_Instruct=ModelConfig(
+    "Llama-3-8B-Instruct": ModelConfig(
         model_definition=lora_llama3_8b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3",
     ),
-)
+}
 
 
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
@@ -85,14 +80,9 @@ BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
 
 
-def _modify_model_id(model_id: str) -> str:
-    return model_id.replace("-", "_").replace(".", "_")
-
-
 def _validate_model_id(model_id: str) -> Model:
     model = resolve_model(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    if model is None or not hasattr(MODEL_CONFIGS, modified_model_id):
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
     return model
 
@@ -101,8 +91,7 @@ async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -112,8 +101,7 @@ async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
    if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -127,8 +115,7 @@ async def get_checkpointer_model_type(
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
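
Note for reviewers: the sketch below illustrates, outside the diff, the lookup pattern this patch moves to. MODEL_CONFIGS becomes a plain dict keyed by the core_model_id string, so the old _modify_model_id() name munging and the getattr()/hasattr() access on a Pydantic container are no longer needed. This is a minimal, self-contained illustration under stated assumptions, not the provider's actual module: ModelConfig is a stand-in dataclass rather than the Pydantic model, _stub_builder replaces the torchtune builders (lora_llama3_2_3b, llama3_tokenizer), and the accessor takes a core_model_id string directly instead of going through resolve_model().

from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class ModelConfig:
    """Stand-in for the provider's Pydantic ModelConfig (same field names)."""

    model_definition: Callable[..., object]
    tokenizer_type: Callable[..., object]
    checkpoint_type: str


def _stub_builder() -> object:
    """Placeholder for torchtune builders such as lora_llama3_2_3b."""
    return object()


# As in the patch: a plain dict keyed by the core_model_id string, so keys can
# keep their "-" and "." characters and lookups are ordinary dict operations.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Llama3.2-3B-Instruct": ModelConfig(
        model_definition=_stub_builder,
        tokenizer_type=_stub_builder,
        checkpoint_type="LLAMA3_2",
    ),
}


def get_checkpointer_model_type(core_model_id: Optional[str]) -> str:
    # A membership test replaces hasattr() on the old ModelConfigs instance.
    if core_model_id is None or core_model_id not in MODEL_CONFIGS:
        raise ValueError(f"Model {core_model_id} is not supported.")
    return MODEL_CONFIGS[core_model_id].checkpoint_type


if __name__ == "__main__":
    print(get_checkpointer_model_type("Llama3.2-3B-Instruct"))  # prints LLAMA3_2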