This commit is contained in:
Botao Chen 2024-12-10 20:55:19 -08:00
parent 68ebf8a8da
commit d7d19dc0e5

View file

@ -87,6 +87,10 @@ async def get_tokenizer_type(
async def get_checkpointer_model_type(
model_id: str,
) -> str:
"""
checkpointer model type is used in checkpointer for some special treatment on some specific model types
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
"""
model = resolve_model(model_id)
return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]