diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index 23df2e287..dfbaf1a3e 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -106,8 +106,6 @@ class Llama:
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
 
-        # TODO(ashwin): this block is so we can load internal checkpoints without additional
-        # fuss. the final code should _not_ have this blurb
         if "model" in params:
             params = params["model"]
 
@@ -130,34 +128,23 @@ class Llama:
         )
 
         if fp8:
+            from .quantization.loader import convert_to_quantized_model
+
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
-
-        model = Transformer(model_args)
-
-        if fp8:
-            # load on CPU first since if we are doing fp8, we probably don't
-            # have enough memory on GPU for bf16
+            model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
-
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-        else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-
-        if not fp8:
-            model.load_state_dict(state_dict, strict=False)
-
-        if config.quantization:
-            from .quantization.loader import convert_to_quantized_model
-
             model = convert_to_quantized_model(model, config)
         else:
-            model = model.to("cuda")
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            else:
+                torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            model = Transformer(model_args)
+            model.load_state_dict(state_dict, strict=False)
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-
         return Llama(model, tokenizer, model_args)
 
     def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
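
For reference, the consolidated loading path after this change reads roughly like the sketch below. This is not a verbatim excerpt of Llama.build: the setup of model_args, state_dict, config, and the fp8 flag is assumed to happen beforehand exactly as in the method, and Transformer and convert_to_quantized_model come from the surrounding meta_reference package.

    # Sketch only: mirrors the post-patch control flow; checkpoint/config
    # preparation is assumed to be done by the caller as in Llama.build.
    if fp8:
        from .quantization.loader import convert_to_quantized_model

        # Build and load on CPU in bf16 so the fp8 conversion sees the
        # dtype it expects rather than, say, fp32.
        torch.set_default_tensor_type(torch.BFloat16Tensor)
        model = Transformer(model_args)
        model.load_state_dict(state_dict, strict=False)
        model = convert_to_quantized_model(model, config)
    else:
        # Materialize directly as CUDA tensors: bf16 when the GPU supports
        # it, fp16 otherwise.
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(state_dict, strict=False)

Each branch now picks its default tensor type, constructs the Transformer, and loads the weights in one place, instead of the previous interleaved if fp8 / if not fp8 / if config.quantization blocks.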