diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
index 56a3c2fa2..6f7c18f72 100644
--- a/llama_stack/providers/inline/post_training/torchtune/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/utils.py
@@ -16,6 +16,8 @@ from typing import Any, Callable, Dict, List
 import torch
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.common.type_system import *  # noqa
+from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.common.type_system import ParamType
+from pydantic import BaseModel

@@ -31,18 +33,29 @@ class ColumnName(Enum):
     text = "text"


-MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
-    "Llama3.2-3B-Instruct": {
-        "model_definition": lora_llama3_2_3b,
-        "tokenizer_type": llama3_tokenizer,
-        "checkpoint_type": "LLAMA3_2",
-    },
-    "Llama-3-8B-Instruct": {
-        "model_definition": lora_llama3_8b,
-        "tokenizer_type": llama3_tokenizer,
-        "checkpoint_type": "LLAMA3",
-    },
-}
+class ModelConfig(BaseModel):
+    model_definition: Any
+    tokenizer_type: Any
+    checkpoint_type: str
+
+
+class ModelConfigs(BaseModel):
+    Llama3_2_3B_Instruct: ModelConfig
+    Llama_3_8B_Instruct: ModelConfig
+
+
+MODEL_CONFIGS = ModelConfigs(
+    Llama3_2_3B_Instruct=ModelConfig(
+        model_definition=lora_llama3_2_3b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3_2",
+    ),
+    Llama_3_8B_Instruct=ModelConfig(
+        model_definition=lora_llama3_8b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3",
+    ),
+)

 EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
     "alpaca": [
@@ -68,20 +81,43 @@ BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


+def _modify_model_id(model_id: str) -> str:
+    # Make a core model id usable as an attribute name on ModelConfigs.
+    return model_id.replace("-", "_").replace(".", "_")
+
+
+def _validate_model_id(model_id: str) -> Model:
+    # Check for None before touching model.core_model_id to avoid an
+    # AttributeError on unresolvable model ids.
+    model = resolve_model(model_id)
+    if model is None:
+        raise ValueError(f"Model {model_id} is not supported.")
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    if not hasattr(MODEL_CONFIGS, modified_model_id):
+        raise ValueError(f"Model {model_id} is not supported.")
+    return model
+
+
 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-    model = resolve_model(model_id)
-    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
-        raise ValueError(f"Model {model_id} is not supported.")
-    return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "model_definition"):
+        raise ValueError(f"Model {model_id} does not have model definition.")
+    return model_config.model_definition


 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-    model = resolve_model(model_id)
-    return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "tokenizer_type"):
+        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
+    return model_config.tokenizer_type


 async def get_checkpointer_model_type(
@@ -91,8 +127,12 @@ async def get_checkpointer_model_type(
     """
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-    model = resolve_model(model_id)
-    return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "checkpoint_type"):
+        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
+    return model_config.checkpoint_type

 async def validate_input_dataset_schema(
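For reviewers, here is a short standalone sketch (not part of the patch) of the renaming scheme the new helpers rely on: core model ids such as `Llama3.2-3B-Instruct` contain `.` and `-`, which are illegal in Python identifiers, so `_modify_model_id` rewrites them to match the field names on the pydantic `ModelConfigs` model before the `hasattr`/`getattr` lookups. The `_Config`/`_Configs` classes below are hypothetical stand-ins for the patch's `ModelConfig`/`ModelConfigs`, trimmed to a single field so the snippet runs on its own.

```python
from pydantic import BaseModel


def _modify_model_id(model_id: str) -> str:
    # Same mapping as the patch: "." and "-" are not valid in Python
    # attribute names, so both become "_".
    return model_id.replace("-", "_").replace(".", "_")


class _Config(BaseModel):  # hypothetical stand-in for ModelConfig
    checkpoint_type: str


class _Configs(BaseModel):  # hypothetical stand-in for ModelConfigs
    Llama3_2_3B_Instruct: _Config


configs = _Configs(Llama3_2_3B_Instruct=_Config(checkpoint_type="LLAMA3_2"))

attr = _modify_model_id("Llama3.2-3B-Instruct")
assert attr == "Llama3_2_3B_Instruct"
assert hasattr(configs, attr)
assert getattr(configs, attr).checkpoint_type == "LLAMA3_2"
```

One consequence of this design: attribute-style lookup gives each supported model a typed, validated entry, but adding support for a new model now means adding a field to `ModelConfigs` rather than a key to a dict.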