Merge branch 'meta-llama:main' into main

This commit is contained in:
Chacksu 2024-11-21 15:47:54 -05:00 committed by GitHub
commit 19bc7e8942
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 244 additions and 173 deletions

View file

@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import math
import os
import sys
@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
@ -50,6 +50,8 @@ from .config import (
MetaReferenceQuantizedInferenceConfig,
)
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -185,7 +187,7 @@ class Llama:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
def __init__(
@ -221,7 +223,7 @@ class Llama:
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
bsz = 1
@ -231,9 +233,7 @@ class Llama:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import logging
from typing import AsyncGenerator, List
@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
log.info(f"Loading model `{self.model.descriptor()}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()

View file

@ -11,6 +11,7 @@
# the root directory of this source tree.
import json
import logging
import multiprocessing
import os
import tempfile
@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
log = logging.getLogger(__name__)
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
print("quitting generation loop because request was cancelled")
log.info(
"quitting generation loop because request was cancelled"
)
break
if mp_rank_0():
send_obj(EndSentinel())
except Exception as e:
print(f"[debug] got exception {e}")
import traceback
log.exception("exception in generation loop")
traceback.print_exc()
if mp_rank_0():
send_obj(ExceptionResponse(error=str(e)))
@ -252,7 +255,7 @@ def worker_process_entrypoint(
except StopIteration:
break
print("[debug] worker process done")
log.info("[debug] worker process done")
def launch_dist_group(
@ -313,7 +316,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print("Loaded model...")
log.info("Loaded model...")
return request_socket, process
@ -361,7 +364,7 @@ class ModelParallelProcessGroup:
break
if isinstance(obj, ExceptionResponse):
print(f"[debug] got exception {obj.error}")
log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error)
if isinstance(obj, TaskResponse):

View file

@ -8,14 +8,20 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import collections
import logging
from typing import Optional, Type
log = logging.getLogger(__name__)
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
print("Using efficient FP8 operators in FBGEMM.")
log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
log.error(
"No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
)
raise
import torch

View file

@ -7,6 +7,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
import os
from typing import Any, Dict, List, Optional
@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType
from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
def swiglu_wrapper(
self,
@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):

View file

@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import logging
import os
import shutil
import sys
@ -22,12 +23,18 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import FfnQuantizeMode, quantize_fp8
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import Tokenizer
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from torch.nn.parameter import Parameter
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
quantize_fp8,
)
log = logging.getLogger(__name__)
def main(
ckpt_dir: str,
@ -36,7 +43,6 @@ def main(
max_seq_len: Optional[int] = 512,
max_batch_size: Optional[int] = 4,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.FP8_ROWWISE,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
):
@ -99,7 +105,7 @@ def main(
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path)
log.info(ckpt_path)
assert (
quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None"
@ -112,7 +118,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -124,7 +129,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():
@ -136,7 +140,6 @@ def main(
fp8_weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cpu"),
)
with torch.inference_mode():

View file

@ -9,7 +9,7 @@
set -euo pipefail
set -x
cd $(git rev-parse --show-toplevel)
cd $(dirname "$(realpath "$0")")
MASTER_HOST=$1
RUN_ID=$2
@ -21,7 +21,7 @@ NPROC=$7
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" \
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \