fix: llama3 bf16 model load

Ashwin Bharambe 2025-04-09 01:09:16 -07:00
parent e3d22d8de7
commit 45e210fd0c

@@ -119,17 +119,16 @@ class Llama3:
             torch.set_default_device(device)
         else:
             print(f"Setting default device to {device}")
             torch.set_default_device(device)
-
         if device.type == "cuda":
             if torch.cuda.is_bf16_supported():
-                torch.set_default_dtype(torch.bfloat16)
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
             else:
-                torch.set_default_dtype(torch.half)
+                torch.set_default_tensor_type(torch.cuda.Float16Tensor)
         elif device.type == "xpu":
             if torch.xpu.is_bf16_supported():
-                torch.set_default_dtype(torch.bfloat16)
+                torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
             else:
-                torch.set_default_dtype(torch.half)
+                torch.set_default_tensor_type(torch.xpu.Float16Tensor)
         model = build_model()
         print("Loading state dict...")