Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 10:54:19 +00:00)
Avoid using nearly double the memory needed (#30)
commit 00f0e6d92b (parent b311dcd143)

1 changed file with 9 additions and 22 deletions
@@ -106,8 +106,6 @@ class Llama:
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
 
-        # TODO(ashwin): this block is so we can load internal checkpoints without additional
-        # fuss. the final code should _not_ have this blurb
         if "model" in params:
             params = params["model"]
 
@@ -130,34 +128,23 @@ class Llama:
         )
 
         if fp8:
+            from .quantization.loader import convert_to_quantized_model
             # load on CPU in bf16 so that fp8 conversion does not find an
             # unexpected (fp32, e.g.) datatype
             torch.set_default_tensor_type(torch.BFloat16Tensor)
-
-        model = Transformer(model_args)
-
-        if fp8:
-            # load on CPU first since if we are doing fp8, we probably don't
-            # have enough memory on GPU for bf16
+            model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
-
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-        else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-
-        if not fp8:
+            model = convert_to_quantized_model(model, config)
+        else:
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            else:
+                torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
-
-        if config.quantization:
-            from .quantization.loader import convert_to_quantized_model
-
-            model = convert_to_quantized_model(model, config)
-        else:
-            model = model.to("cuda")
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
         return Llama(model, tokenizer, model_args)
 
     def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
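For readers wondering where the memory saving comes from, below is a minimal, self-contained sketch (not part of the commit) contrasting the two load patterns the diff switches between. The toy nn.Linear stands in for Transformer, and the helper names build_then_move and build_in_place are hypothetical; only torch.set_default_tensor_type, torch.cuda.is_bf16_supported, and load_state_dict are taken from the code above.

# Sketch only: toy module and hypothetical helper names, not llama-stack code.
import torch
import torch.nn as nn


def build_then_move(state_dict):
    # Old-style pattern: the module is materialized with the current default
    # tensor type (typically fp32 on CPU), the checkpoint is copied into it,
    # and .to("cuda") then allocates a second full copy of the weights on the
    # GPU, still at the original dtype.
    model = nn.Linear(1024, 1024, bias=False)
    model.load_state_dict(state_dict, strict=False)
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model


def build_in_place(state_dict):
    # New-style pattern: choose the target default tensor type first, so the
    # parameters are created once, already in the dtype/device they will be
    # used in, and the checkpoint is copied straight into that allocation.
    if torch.cuda.is_available():
        torch.set_default_tensor_type(
            torch.cuda.BFloat16Tensor
            if torch.cuda.is_bf16_supported()
            else torch.cuda.HalfTensor
        )
    model = nn.Linear(1024, 1024, bias=False)
    model.load_state_dict(state_dict, strict=False)
    return model


if __name__ == "__main__":
    ckpt = {"weight": torch.randn(1024, 1024, dtype=torch.bfloat16)}
    model = build_in_place(ckpt)
    weight = next(model.parameters())
    print(weight.dtype, weight.device)

Under that reading, the extra allocation in the old path (the module built under the default tensor type plus the fresh copy created when it is moved to the GPU) is presumably the "nearly double the memory" the commit title refers to; the new path creates the parameters once, in their final dtype, and copies the checkpoint into them in place.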