Fix post training apis broken by torchtune release (#674)

There is a torchtune release this morning
https://github.com/pytorch/torchtune/releases/tag/v0.5.0 and breaks post
training apis

## test 
spinning up server and the post training works again after the fix 
<img width="1314" alt="Screenshot 2024-12-20 at 4 08 54 PM"
src="https://github.com/user-attachments/assets/dfae724d-ebf0-4846-9715-096efa060cee"
/>


## Note
We need to think hard of how to avoid this happen again and have a fast
follow up on this after holidays
This commit is contained in:
Botao Chen 2024-12-20 16:12:02 -08:00 committed by GitHub
parent 06cb0c837e
commit bae197c37e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -43,7 +43,6 @@ from torchtune.modules.peft import (
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
@ -281,7 +280,6 @@ class LoraFinetuningSingleDevice:
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False