temp commit

This commit is contained in:
Botao Chen 2025-01-03 14:30:08 -08:00
parent 346a6c658d
commit 82d575811c

View file

@ -448,6 +448,10 @@ class LoraFinetuningSingleDevice:
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
# print("tokens", batch["tokens"])
torch.save(batch["tokens"], "/home/markchen1015/new_alpaca_tokens.pth")
# print("labels", batch["labels"])
torch.save(batch["labels"], "/home/markchen1015/new_alpaca_labels.pth")
labels = batch.pop("labels")
# run model
with self.activations_handling_ctx: