forked from phoenix-oss/llama-stack-mirror
fix: llama3 bf16 model load
This commit is contained in:
parent
e3d22d8de7
commit
45e210fd0c
1 changed file with 4 additions and 5 deletions
|
@@ -119,17 +119,16 @@ class Llama3:
|
|||
torch.set_default_device(device)
|
||||
else:
|
||||
print(f"Setting default device to {device}")
|
||||
torch.set_default_device(device)
|
||||
if device.type == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
|
||||
elif device.type == "xpu":
|
||||
if torch.xpu.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
|
||||
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue