diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py index 6cbee8766..56a3c2fa2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/utils.py @@ -87,6 +87,10 @@ async def get_tokenizer_type( async def get_checkpointer_model_type( model_id: str, ) -> str: + """ + checkpointer model type is used in checkpointer for some special treatment on some specific model types + For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041) + """ model = resolve_model(model_id) return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]