address comments

Botao Chen 2024-12-12 14:05:40 -08:00
parent d7d19dc0e5
commit 3378c100f6

@@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List

 import torch
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.common.type_system import * # noqa
+from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.common.type_system import ParamType
@@ -31,18 +32,29 @@ class ColumnName(Enum):
     text = "text"

-MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
-    "Llama3.2-3B-Instruct": {
-        "model_definition": lora_llama3_2_3b,
-        "tokenizer_type": llama3_tokenizer,
-        "checkpoint_type": "LLAMA3_2",
-    },
-    "Llama-3-8B-Instruct": {
-        "model_definition": lora_llama3_8b,
-        "tokenizer_type": llama3_tokenizer,
-        "checkpoint_type": "LLAMA3",
-    },
-}
+class ModelConfig(BaseModel):
+    model_definition: Any
+    tokenizer_type: Any
+    checkpoint_type: str
+
+
+class ModelConfigs(BaseModel):
+    Llama3_2_3B_Instruct: ModelConfig
+    Llama_3_8B_Instruct: ModelConfig
+
+
+MODEL_CONFIGS = ModelConfigs(
+    Llama3_2_3B_Instruct=ModelConfig(
+        model_definition=lora_llama3_2_3b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3_2",
+    ),
+    Llama_3_8B_Instruct=ModelConfig(
+        model_definition=lora_llama3_8b,
+        tokenizer_type=llama3_tokenizer,
+        checkpoint_type="LLAMA3",
+    ),
+)
+

 EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
     "alpaca": [
@@ -68,20 +80,38 @@ BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


+def _modify_model_id(model_id: str) -> str:
+    return model_id.replace("-", "_").replace(".", "_")
+
+
+def _validate_model_id(model_id: str) -> Model:
+    model = resolve_model(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    if model is None or not hasattr(MODEL_CONFIGS, modified_model_id):
+        raise ValueError(f"Model {model_id} is not supported.")
+    return model
+
+
 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-    model = resolve_model(model_id)
-    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
-        raise ValueError(f"Model {model_id} is not supported.")
-    return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "model_definition"):
+        raise ValueError(f"Model {model_id} does not have model definition.")
+    return model_config.model_definition


 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-    model = resolve_model(model_id)
-    return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "tokenizer_type"):
+        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
+    return model_config.tokenizer_type


 async def get_checkpointer_model_type(
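
Note: _modify_model_id exists because ModelConfigs field names are Python attribute names, while core_model_id values contain dashes and dots; the helper maps one onto the other so hasattr/getattr work. A quick check of the mapping (pure string manipulation, runnable standalone):

    def _modify_model_id(model_id: str) -> str:
        return model_id.replace("-", "_").replace(".", "_")

    assert _modify_model_id("Llama3.2-3B-Instruct") == "Llama3_2_3B_Instruct"
    assert _modify_model_id("Llama-3-8B-Instruct") == "Llama_3_8B_Instruct"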
@@ -91,8 +121,12 @@ async def get_checkpointer_model_type(
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-    model = resolve_model(model_id)
-    return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
+    model = _validate_model_id(model_id)
+    modified_model_id = _modify_model_id(model.core_model_id.value)
+    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    if not hasattr(model_config, "checkpoint_type"):
+        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
+    return model_config.checkpoint_type


 async def validate_input_dataset_schema(
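
End-to-end, the three getters share the same validate-then-getattr pattern. A usage sketch, assuming these functions are in scope and that resolve_model knows the id (the id below is one of the two configured models; anything else raises ValueError):

    import asyncio

    async def main() -> None:
        model_def = await get_model_definition("Llama3.2-3B-Instruct")
        tokenizer = await get_tokenizer_type("Llama3.2-3B-Instruct")
        ckpt_type = await get_checkpointer_model_type("Llama3.2-3B-Instruct")
        print(ckpt_type)  # "LLAMA3_2" per MODEL_CONFIGS above

    asyncio.run(main())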