cleanup for fp8 and requirements etc

This commit is contained in:
Ashwin Bharambe 2024-07-20 23:21:41 -07:00
parent 2428701951
commit d73fed5cc3
9 changed files with 78 additions and 905 deletions

31
fp8_requirements.txt Normal file
View file

@ -0,0 +1,31 @@
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch>=2.4.0.dev20240531,<2.4.1
accelerate
black==24.4.2
codeshield
fairscale
fastapi
fire
flake8
huggingface-hub
httpx
hydra-core
hydra-zen
json-strong-typing
matplotlib
omegaconf
pandas
Pillow
pre-commit
pydantic==1.10.13
pydantic_core==2.18.2
python-dotenv
python-openapi
requests
tiktoken
transformers
ufmt==2.7.0
usort==1.0.8
uvicorn
zmq
fbgemm-gpu==0.8.0rc4

View file

@ -1,45 +0,0 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1
fi
ENV_NAME=$1
set -eu
eval "$(conda shell.bash hook)"
echo "Will build env (or overwrite) named '$ENV_NAME'"
set -x
run_build() {
# Set CUDA 9.0a targets
export CUDA_ARCH_LIST="8.0;9.0a"
export NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a"
export TORCH_CUDA_ARCH_LIST=$CUDA_ARCH_LIST
# Set up the conda environment
yes | conda remove --name $ENV_NAME --all
yes | conda create -n $ENV_NAME python=3.10
conda activate $ENV_NAME
yes | conda install --channel "nvidia/label/cuda-12.1.0" cuda
yes | conda install cuda-nvtx cuda-nvtx-dev conda-forge::nccl
# ############# Hack to get CUDA path #############
ln -s $CONDA_PREFIX/targets/x86_64-linux/include/* $CONDA_PREFIX/include/ || true
export CUDA_HOME=$CONDA_PREFIX
export CUDA_BIN_PATH=$CUDA_HOME
# #################################################
# PT nightly
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
# install dependencies for `llama-agentic-system`
pip install -r fp8_requirements.txt
}
run_build

View file

@ -1,5 +0,0 @@
fairscale
fire
tiktoken
blobfile
fbgemm-gpu==0.8.0rc4

View file

@ -1,455 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from fp8.fp8_impls import (
FfnQuantizeMode,
Fp8ScaledWeights,
load_fp8,
ModelLoadMode,
quantize_fp8,
)
from llama.model import ModelArgs, Transformer, TransformerBlock
from llama.tokenizer import ChatFormat, Dialog, Message, ModelInput, Tokenizer
class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
logprobs: List[float] # not required
class ChatPrediction(TypedDict, total=False):
generation: Message
tokens: List[str] # not required
logprobs: List[float] # not required
class Llama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
model_parallel_size: Optional[int] = None,
ffn_quantize_mode: Optional[FfnQuantizeMode] = FfnQuantizeMode.NONE,
model_load_mode: Optional[ModelLoadMode] = ModelLoadMode.BF16,
fp8_activation_scale_ub: Optional[float] = 1200.0,
seed: int = 1,
) -> "Llama":
"""
Build a Llama instance by initializing and loading a model checkpoint.
Args:
ckpt_dir (str): Path to the directory containing checkpoint files.
tokenizer_path (str): Path to the tokenizer file.
max_seq_len (int): Maximum sequence length for input text.
max_batch_size (int): Maximum batch size for inference.
model_parallel_size (Optional[int], optional): Number of model parallel processes.
If not provided, it's determined from the environment. Defaults to None.
Returns:
Llama: An instance of the Llama class with the loaded model and tokenizer.
Raises:
AssertionError: If there are no checkpoint files in the specified directory,
or if the model parallel size does not match the number of checkpoint files.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print("ffn_quantize_mode: ", ffn_quantize_mode)
if ffn_quantize_mode == FfnQuantizeMode.FP8_ROWWISE:
# Move weights to GPU with quantization
if model_load_mode == ModelLoadMode.FP8:
fp8_scales_path = os.path.join(
ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path
), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (
model.n_layers - 1
):
continue
block.feed_forward.w1.weight = load_fp8(
block.feed_forward.w1.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w3.weight = load_fp8(
block.feed_forward.w3.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w2.weight = load_fp8(
block.feed_forward.w2.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (
model.n_layers - 1
):
continue
block.feed_forward.w1.weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
block.feed_forward.w3.weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
block.feed_forward.w2.weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
ffn_quantize_mode,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
else:
for _, parameter in model.named_parameters():
parameter.data = parameter.to(device="cuda")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_inputs: List[ModelInput],
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
Generate text sequences based on provided prompts using the language generation model.
Args:
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
params = self.model.params
prompt_tokens = [m.tokens for m in model_inputs]
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
prev_pos = cur_pos
if all(eos_reached):
break
if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
if include_stop_token:
eos_idx += 1
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
def text_completion(
self,
prompts: List[str],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> List[CompletionPrediction]:
"""
Perform text completion for a list of prompts using the language generation model.
Args:
prompts (List[str]): List of text prompts for completion.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
If not provided, it's set to the model's maximum sequence length minus 1.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
Note:
This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
generation_tokens, generation_logprobs = self.generate(
model_inputs=[ModelInput(tokens=pt) for pt in prompt_tokens],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
if logprobs:
return [
{
"generation": self.tokenizer.decode(t),
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
def chat_completion(
self,
dialogs: List[Dialog],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
) -> List[ChatPrediction]:
"""
Generate assistant responses for a list of conversational dialogs using the language generation model.
Args:
dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
If not provided, it's set to the model's maximum sequence length minus 1.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
Returns:
List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
Note:
This method generates assistant responses for the provided conversational dialogs.
It employs nucleus sampling to introduce controlled randomness in text generation.
If logprobs is True, token log probabilities are computed for each generated token.
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
model_inputs = [
self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
]
generation_tokens, generation_logprobs = self.generate(
model_inputs=model_inputs,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
if logprobs:
return [
{
"generation": self.formatter.decode_assistant_message(t),
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [
{
"generation": self.formatter.decode_assistant_message(t),
}
for t in generation_tokens
]
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -5,7 +5,6 @@ import os
from typing import Optional
import torch
from torch import Tensor
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from models.llama3_1.api.model import Transformer, TransformerBlock
@ -17,6 +16,7 @@ from toolchain.inference.api.config import (
InlineImplConfig,
)
from toolchain.inference.api.datatypes import QuantizationType
from torch import Tensor
def is_fbgemm_available() -> bool:
@ -69,27 +69,15 @@ def convert_to_quantized_model(
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
block.feed_forward.w1.weight = load_fp8(
block.feed_forward.w1.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w3.weight = load_fp8(
block.feed_forward.w3.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
block.feed_forward.w2.weight = load_fp8(
block.feed_forward.w2.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[
f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"
],
fp8_activation_scale_ub,
)
else:
cprint("Quantizing fp8 weights from bf16...", "yellow")
for block in model.layers:
@ -97,21 +85,13 @@ def convert_to_quantized_model(
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
block.feed_forward.w1.weight = quantize_fp8(
block.feed_forward.w1.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
block.feed_forward.w3.weight = quantize_fp8(
block.feed_forward.w3.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
block.feed_forward.w2.weight = quantize_fp8(
block.feed_forward.w2.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):

View file

@ -1,363 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from fp8.fp8_impls import ffn_swiglu
from torch import nn
@dataclass
class QuantizationArgs:
fp8_rowwise: bool = False
convert_from_bf16: bool = False
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
quantization: Optional[QuantizationArgs] = None
max_batch_size: int = 32
max_seq_len: int = 2048
def __init__(self, **kwargs):
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.w3 = ColumnParallelLinear(
dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
def forward(self, x):
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output

View file

@ -0,0 +1,30 @@
#!/bin/bash
if [[ $# -ne 1 ]]; then
echo "Error: Please provide the name of CONDA environment you wish to create"
exit 1
fi
ENV_NAME=$1
set -eu
eval "$(conda shell.bash hook)"
echo "Will build env (or overwrite) named '$ENV_NAME'"
set -x
run_build() {
# Set up the conda environment
yes | conda remove --name $ENV_NAME --all
yes | conda create -n $ENV_NAME python=3.10
conda activate $ENV_NAME
# PT nightly
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
# install dependencies for `llama-agentic-system`
pip install -r fp8_requirements.txt
}
run_build