From 45e210fd0c43ee76a93f21275575f3cbd83a70f6 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 9 Apr 2025 01:09:16 -0700 Subject: [PATCH] fix: llama3 bf16 model load --- llama_stack/models/llama/llama3/generation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index ee99a07ba..8c6aa242b 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -119,17 +119,16 @@ class Llama3: torch.set_default_device(device) else: print(f"Setting default device to {device}") - torch.set_default_device(device) if device.type == "cuda": if torch.cuda.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) else: - torch.set_default_dtype(torch.half) + torch.set_default_tensor_type(torch.cuda.HalfTensor) elif device.type == "xpu": if torch.xpu.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) + torch.set_default_tensor_type(torch.xpu.BFloat16Tensor) else: - torch.set_default_dtype(torch.half) + torch.set_default_tensor_type(torch.xpu.HalfTensor) model = build_model() print("Loading state dict...")