forked from phoenix-oss/llama-stack-mirror
Avoid using nearly double the memory needed (#30)
parent b311dcd143
commit 00f0e6d92b
1 changed file with 9 additions and 22 deletions
@@ -106,8 +106,6 @@ class Llama:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        # TODO(ashwin): this block is so we can load internal checkpoints without additional
        # fuss. the final code should _not_ have this blurb
        if "model" in params:
            params = params["model"]
@@ -130,34 +128,23 @@ class Llama:
        )

        if fp8:
            from .quantization.loader import convert_to_quantized_model

            # load on CPU in bf16 so that fp8 conversion does not find an
            # unexpected (fp32, e.g.) datatype
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args)

        if fp8:
            # load on CPU first since if we are doing fp8, we probably don't
            # have enough memory on GPU for bf16
            model.load_state_dict(state_dict, strict=False)

            model = convert_to_quantized_model(model, config)
        else:
            if torch.cuda.is_bf16_supported():
                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            else:
                torch.set_default_tensor_type(torch.cuda.HalfTensor)

        if not fp8:
            model = Transformer(model_args)
            model.load_state_dict(state_dict, strict=False)

        if config.quantization:
            from .quantization.loader import convert_to_quantized_model

            model = convert_to_quantized_model(model, config)
        else:
            model = model.to("cuda")

        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer, model_args)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
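As a rough illustration of the "nearly double the memory" idea named in the commit title, the sketch below contrasts building a model with the default CPU tensor type and then moving it to the GPU (two full copies of the weights briefly coexist) with picking the target default tensor type up front so the parameters are allocated once, in the dtype and on the device where they will be used. This is a hedged sketch based only on the set_default_tensor_type and .to("cuda") calls visible in the hunk above; TinyModel, load_with_extra_copy, and load_once are hypothetical stand-ins, not the repository's Transformer or its loader, and running it requires a CUDA device.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for a large Transformer; any nn.Module shows the point.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

def load_with_extra_copy(state_dict: dict) -> nn.Module:
    # Parameters are first materialized with the default (CPU, fp32) tensor type,
    # then copied to the GPU, so CPU and GPU copies exist at the same time.
    model = TinyModel()
    model.load_state_dict(state_dict, strict=False)
    return model.to("cuda")

def load_once(state_dict: dict) -> nn.Module:
    # Choose the target default tensor type first so parameters are allocated
    # once on the GPU; load_state_dict then copies the checkpoint into them in place.
    if torch.cuda.is_bf16_supported():
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = TinyModel()
    model.load_state_dict(state_dict, strict=False)
    return model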