diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index ee99a07ba..8c6aa242b 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -119,17 +119,16 @@ class Llama3: torch.set_default_device(device) else: print(f"Setting default device to {device}") - torch.set_default_device(device) if device.type == "cuda": if torch.cuda.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) else: - torch.set_default_dtype(torch.half) + torch.set_default_tensor_type(torch.cuda.HalfTensor) elif device.type == "xpu": if torch.xpu.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) + torch.set_default_tensor_type(torch.xpu.BFloat16Tensor) else: - torch.set_default_dtype(torch.half) + torch.set_default_tensor_type(torch.xpu.HalfTensor) model = build_model() print("Loading state dict...")