Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00
make inference server load checkpoints for fp8 inference
- introduce quantization-related args for inference config
- also kill GeneratorArgs
parent 7d2c0b14b8
commit ad62e2e1f3

10 changed files with 249 additions and 155 deletions
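For callers, the change means Llama.build no longer takes ckpt_dir, tokenizer_path, max_seq_len, and friends as separate arguments (the old GeneratorArgs-style surface); everything arrives through a single InlineImplConfig. A rough sketch of the new call site follows. Only the attribute names that appear in the diff below are grounded in this commit; the constructor shapes and the PytorchCheckpoint, CheckpointConfig, and QuantizationConfig class names are assumptions for illustration.

    from .api.config import CheckpointConfig, InlineImplConfig, PytorchCheckpoint
    from .api.datatypes import QuantizationConfig, QuantizationType

    # Hypothetical construction; the field names mirror the attribute
    # accesses in the diff, but the constructors are not shown in this commit.
    config = InlineImplConfig(
        checkpoint_config=CheckpointConfig(
            checkpoint=PytorchCheckpoint(
                checkpoint_dir="/path/to/checkpoints",
                tokenizer_path="/path/to/tokenizer.model",
                model_parallel_size=1,
            ),
        ),
        quantization=QuantizationConfig(type=QuantizationType.fp8.value),
        max_seq_len=2048,
        max_batch_size=1,
        torch_seed=1,
    )
    llama = Llama.build(config)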
@@ -7,7 +7,7 @@ import sys
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional, TypedDict
+from typing import Generator, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -23,6 +23,9 @@ from models.llama3_1.api.model import Transformer
 from models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint
 
+from .api.config import CheckpointType, InlineImplConfig
+from .api.datatypes import QuantizationType
+
 
 @dataclass
 class TokenResult:
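Neither enum is defined in this diff; from the .value comparisons in the hunk below, a minimal sketch of what they plausibly look like (the member values are assumptions):

    from enum import Enum

    # Hypothetical definitions, inferred only from usage such as
    # CheckpointType.pytorch.value and QuantizationType.fp8.value.
    class CheckpointType(Enum):
        pytorch = "pytorch"
        huggingface = "huggingface"

    class QuantizationType(Enum):
        bf16 = "bf16"
        fp8 = "fp8"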
@@ -31,69 +34,52 @@ class TokenResult:
     logprobs: Optional[List[float]] = None
 
 
-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 class Llama:
     @staticmethod
-    def build(
-        ckpt_dir: str,
-        tokenizer_path: str,
-        max_seq_len: int,
-        max_batch_size: int,
-        model_parallel_size: Optional[int] = None,
-        seed: int = 1,
-    ) -> "Llama":
+    def build(config: InlineImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
 
-        Args:
-            ckpt_dir (str): Path to the directory containing checkpoint files.
-            tokenizer_path (str): Path to the tokenizer file.
-            max_seq_len (int): Maximum sequence length for input text.
-            max_batch_size (int): Maximum batch size for inference.
-            model_parallel_size (Optional[int], optional): Number of model parallel processes.
-                If not provided, it's determined from the environment. Defaults to None.
-
         Returns:
             Llama: An instance of the Llama class with the loaded model and tokenizer.
 
         Raises:
             AssertionError: If there are no checkpoint files in the specified directory,
                 or if the model parallel size does not match the number of checkpoint files.
 
         Note:
             This method initializes the distributed process group, sets the device to CUDA,
             and loads the pre-trained model and tokenizer.
         """
+        checkpoint = config.checkpoint_config.checkpoint
+        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
+            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+
+        if config.quantization and config.quantization.type == QuantizationType.fp8.value:
+            from .quantization.loader import is_fbgemm_available
+
+            if not is_fbgemm_available():
+                raise ImportError("fbgemm-gpu is required for FP8 quantization")
+
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
+        model_parallel_size = checkpoint.model_parallel_size
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
 
-        # seed must be the same in all processes
-        torch.manual_seed(seed)
+        if config.torch_seed is not None:
+            torch.manual_seed(config.torch_seed)
 
         if local_rank > 0:
             sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
+        ckpt_dir = checkpoint.checkpoint_dir
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
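The fp8 path in the hunk above imports is_fbgemm_available from .quantization.loader, which this commit does not show. A minimal sketch, under the assumption that it is a plain import probe:

    def is_fbgemm_available() -> bool:
        # Probe for the fbgemm-gpu package; which submodule the real
        # loader imports is not visible in this diff.
        try:
            import fbgemm_gpu  # noqa: F401
            return True
        except ImportError:
            return False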
@@ -103,22 +89,34 @@ class Llama:
         params = params["model"]
 
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
 
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
 
+        # load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
+        torch.set_default_tensor_type(torch.BFloat16Tensor)
+
+        model = Transformer(model_args)
+        model.load_state_dict(state_dict, strict=False)
+
         if torch.cuda.is_bf16_supported():
             torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
         else:
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
-        model = Transformer(model_args)
-        model.load_state_dict(checkpoint, strict=False)
+
+        if config.quantization:
+            from .quantization.loader import convert_to_quantized_model
+
+            model = convert_to_quantized_model(model, config)
+        else:
+            model = model.to("cuda")
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
         return Llama(model, tokenizer, model_args)
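The ordering in this last hunk matters: weights are materialized on CPU under a bf16 default so that, as the in-diff comment says, the fp8 converter never encounters an unexpected dtype such as fp32, and only then are they either quantized (convert_to_quantized_model, not shown in this commit) or moved to CUDA unchanged. As a rough illustration of the kind of cast such a converter performs, here is a per-tensor e4m3 scaling in plain PyTorch; the per-tensor scheme is an assumption, and the real loader relies on fbgemm-gpu kernels rather than this naive cast:

    import torch

    # Illustration only: scale a bf16 weight into the fp8 e4m3 range and cast.
    w = torch.randn(4096, 4096, dtype=torch.bfloat16)
    scale = w.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
    w_fp8 = (w.float() / scale).to(torch.float8_e4m3fn)          # quantized storage
    w_bf16 = w_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)  # dequantize for use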