Changes from the main repo

This commit is contained in:
Ashwin Bharambe 2024-07-19 16:11:17 -07:00
parent 9c9b834c0f
commit 7d2c0b14b8
8 changed files with 24 additions and 9 deletions

View file

@ -93,7 +93,7 @@ class Llama:
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())