forked from phoenix-oss/llama-stack-mirror
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do? Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference specific tidbits into llama-models code even by accident. Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu. ## Test Plan ``` LLAMA_MODELS_DEBUG=1 \ with-proxy llama stack run meta-reference-gpu \ --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \ --env INFERENCE_CHECKPOINT_DIR=<DIR> \ --env MODEL_PARALLEL_SIZE=4 \ --env QUANTIZATION_TYPE=fp8_mixed ``` Start a server with and without quantization. Point integration tests to it using: ``` pytest -s -v tests/integration/inference/test_text_inference.py \ --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct ```
This commit is contained in:
parent
c52ccc4bbd
commit
530d4bdfe1
85 changed files with 1267 additions and 1683 deletions
164
llama_stack/models/llama/checkpoint.py
Normal file
164
llama_stack/models/llama/checkpoint.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import concurrent.futures
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
|
||||
|
||||
|
||||
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
|
||||
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
|
||||
if new_mp_size % old_mp_size == 0:
|
||||
# Read old MP shard and split it into smaller ones
|
||||
return [new_mp_rank * old_mp_size // new_mp_size]
|
||||
elif old_mp_size % new_mp_size == 0:
|
||||
# Merge old MP shards into a single one
|
||||
mp_factor = old_mp_size // new_mp_size
|
||||
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Either old MP size or new MP size should be a multiple of the other: "
|
||||
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
|
||||
)
|
||||
|
||||
|
||||
def maybe_reshard_state_dict(
|
||||
ckpt_paths: List[Path],
|
||||
n_kv_heads: int,
|
||||
moe_num_experts: Optional[int] = None,
|
||||
map_location: Union[str, torch.device] = "cpu",
|
||||
mmap: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if str(map_location) == "cpu":
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
||||
ckpt_paths = np.array(sorted(ckpt_paths))
|
||||
|
||||
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
|
||||
old_mp_size = len(ckpt_paths)
|
||||
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
|
||||
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
|
||||
paths = ckpt_paths[old_mp_ranks] # type: ignore
|
||||
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
|
||||
|
||||
if new_mp_size == old_mp_size:
|
||||
return state_dicts[0] # type: ignore
|
||||
|
||||
if moe_num_experts is not None:
|
||||
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
|
||||
|
||||
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
|
||||
return reshard_mp(
|
||||
state_dicts,
|
||||
size=max(new_mp_size // old_mp_size, 1),
|
||||
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
|
||||
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
|
||||
)
|
||||
|
||||
|
||||
_WEIGHT_ROW_KEY = {
|
||||
"feed_forward.w2",
|
||||
"feed_forward.mlp.fc2",
|
||||
"attention.wo",
|
||||
"feed_forward.mlp.fc2_weight",
|
||||
"feed_forward.w_out_shared_DF.weight",
|
||||
"attn.wo.weight",
|
||||
"mlp.c_proj.weight",
|
||||
}
|
||||
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
|
||||
|
||||
_WEIGHT_COLUMN_KEY = {
|
||||
"output",
|
||||
"feed_forward.(w1|w3)",
|
||||
"feed_forward.mlp.(fc1|fc3)",
|
||||
"feed_forward.mlp.fc1_weight",
|
||||
"attention.(wk|wq|wv|wqkv).weight",
|
||||
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
|
||||
"attn.(wk|wq|wv).weight",
|
||||
"attn.(wk|wq|wv).bias",
|
||||
"mlp.c_fc.weight",
|
||||
"mlp.c_fc.bias",
|
||||
"conv1._linear.weight",
|
||||
"tok_embeddings.weight",
|
||||
"vision_projection.weight",
|
||||
}
|
||||
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
|
||||
|
||||
|
||||
def reshard_mp(
|
||||
state_dicts: List[Dict[str, torch.Tensor]],
|
||||
size: int,
|
||||
rank: int,
|
||||
repeat_qk_qv: int = 1,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshard a list of state dicts into a single state dict given a change in MP size.
|
||||
If the list has more than one state dict, we concatenate the values of the same
|
||||
key across all state dicts. Otherwise, we just slice it for the current MP rank.
|
||||
"""
|
||||
|
||||
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
|
||||
if len(tensors) > 1:
|
||||
return torch.cat(tensors, dim=dim)
|
||||
return tensors[0].chunk(size, dim=dim)[rank].clone()
|
||||
|
||||
def process_key(key: str) -> torch.Tensor:
|
||||
if row_regex.search(key):
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
|
||||
elif column_regex.search(key):
|
||||
if "w13" in key or "fc1_weight" in key:
|
||||
dims = state_dicts[0][key].size()
|
||||
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
|
||||
return concat_or_chunk(values, dim=1).flatten(0, 1)
|
||||
elif "qkv" in key:
|
||||
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
|
||||
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
|
||||
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
|
||||
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
|
||||
elif "wk.weight" in key or "wv.weight" in key:
|
||||
# Support MP > #kv_head
|
||||
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
|
||||
elif key == "output.bias" or key == "fc.weight":
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
elif "w_" in key:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
|
||||
else:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
else:
|
||||
return state_dicts[0][key].clone()
|
||||
|
||||
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
|
||||
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
|
||||
column_regex = re.compile("|".join(column_keys))
|
||||
row_regex = re.compile("|".join(row_keys))
|
||||
|
||||
output: Dict[str, torch.Tensor] = {}
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# Note: only processes keys in the first state dict.
|
||||
# Assumes keys are the same across all state dicts.
|
||||
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
|
||||
for future in concurrent.futures.as_completed(mappings):
|
||||
output[mappings[future]] = future.result()
|
||||
return output
|
||||
|
||||
|
||||
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
|
||||
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
routed_regex = re.compile("|".join(routed_keys))
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if routed_regex.search(key):
|
||||
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
|
||||
return state_dict
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
|
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
@ -98,6 +89,29 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
@ -140,292 +154,25 @@ class RawMessage(BaseModel):
|
|||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
class GenerationResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
source: Literal["input"] | Literal["output"]
|
||||
|
||||
# index within the batch
|
||||
batch_idx: int
|
||||
# whether generation for this item is already finished. note that tokens can
|
||||
# get returned even afterwards since other items in the batch can still be generating tokens
|
||||
finished: bool
|
||||
# because a batch is parallel processed, useful decoding for one item can correspond to processing
|
||||
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
|
||||
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
|
||||
ignore_token: bool
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SamplingParams(BaseModel):
|
||||
"""Sampling parameters.
|
||||
|
||||
:param strategy: The sampling strategy.
|
||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
||||
your prompt plus max_tokens cannot exceed the model's context length.
|
||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
||||
The returned text will not contain the stop sequence.
|
||||
"""
|
||||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
llama4 = "llama4"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Llama 4 family
|
||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
]:
|
||||
return ModelFamily.llama4
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
recommended_sampling_params: Optional[SamplingParams] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.id.name
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama4:
|
||||
if self.core_model_id in {
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
}:
|
||||
return 262144
|
||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
||||
return 10485760
|
||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
return 1048576
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
||||
class QuantizationMode(str, Enum):
|
||||
none = "none"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
|
86
llama_stack/models/llama/hadamard_utils.py
Normal file
86
llama_stack/models/llama/hadamard_utils.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Hadamard transform.
|
||||
|
||||
This function performs the Hadamard transform on the input tensor 'x'.
|
||||
The Hadamard transform is a linear transformation that multiplies the input
|
||||
tensor by the Hadamard matrix of dimension n x n, where n is the size of
|
||||
the last dimension of the input tensor.
|
||||
"""
|
||||
*_, n = x.shape
|
||||
m = int(math.log2(n))
|
||||
assert n == 1 << m, "n must be a power of 2"
|
||||
x = x[..., None]
|
||||
inv_sqrt2 = 0.5**0.5
|
||||
for _ in range(m):
|
||||
top = x[..., ::2, :] + x[..., 1::2, :]
|
||||
bot = x[..., ::2, :] - x[..., 1::2, :]
|
||||
x = torch.cat((top, bot), dim=-1)
|
||||
x *= inv_sqrt2
|
||||
res = x.squeeze(-2)
|
||||
return res
|
||||
|
||||
|
||||
class HadamardModule(torch.nn.Module):
|
||||
"""A module that applies the Hadamard transform to the input tensor.
|
||||
|
||||
Args:
|
||||
group_size: The size of the groups that the input tensor will be divided into
|
||||
before applying the Hadamard transform.
|
||||
"""
|
||||
|
||||
def __init__(self, group_size: int) -> None:
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
reshape_back = False
|
||||
orig_shape = x.shape
|
||||
if self.group_size != x.shape[-1]:
|
||||
reshape_back = True
|
||||
x = x.reshape(-1, x.shape[-1] // self.group_size, self.group_size)
|
||||
x = hadamard_transform(x)
|
||||
if reshape_back:
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
||||
|
||||
|
||||
def add_hadamard_transform_for_spinquant(model: torch.nn.Module, prefix: str = "") -> None:
|
||||
"""
|
||||
Adds a Hadamard transform to the last linear layer of each feedforward network (FFN) in the model.
|
||||
This function recursively traverses the model's children and looks for layers that match the pattern
|
||||
"layers.<digit>.feed_forward.w2", where <digit> is one or more digits. When such a layer is found,
|
||||
it is replaced with a new sequential module that consists of a HadamardModule followed by the original
|
||||
layer. The HadamardModule applies the Hadamard transform to the input tensor.
|
||||
|
||||
See `SpinQuant <https://arxiv.org/abs/2405.16406>_` paper for more details.
|
||||
|
||||
Args:
|
||||
model: An instance of 'torch.nn.Module' (e.g., Transformer model).
|
||||
prefix: A string prefix to add to the full name of each child module.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
pattern_last_linear_ffn = r"layers.\d+.feed_forward.w2"
|
||||
for module_name, module in model.named_children():
|
||||
child_full_name = prefix + "." + module_name
|
||||
if re.search(pattern_last_linear_ffn, child_full_name):
|
||||
new_module = nn.Sequential(HadamardModule(group_size=module.in_features), module)
|
||||
del module
|
||||
setattr(model, module_name, new_module)
|
||||
else:
|
||||
add_hadamard_transform_for_spinquant(module, (prefix + "." if prefix else prefix) + module_name)
|
75
llama_stack/models/llama/llama3/args.py
Normal file
75
llama_stack/models/llama/llama3/args.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationArgs:
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "scheme":
|
||||
setattr(self, k, QuantizationScheme(v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAArgs:
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
n_kv_heads: Optional[int] = None
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# vision model params
|
||||
vision_chunk_size: int = -1 # image resolution for image models
|
||||
vision_max_num_chunks: int = 4
|
||||
vision_num_cross_attention_layers: int = -1
|
||||
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "lora_args":
|
||||
setattr(self, k, LoRAArgs(**v))
|
||||
elif k == "quantization_args":
|
||||
setattr(self, k, QuantizationArgs(**v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
if self.n_kv_heads is None:
|
||||
self.n_kv_heads = self.n_heads
|
||||
assert self.n_kv_heads <= self.n_heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.dim % self.n_heads == 0
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
|
@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
|
367
llama_stack/models/llama/llama3/generation.py
Normal file
367
llama_stack/models/llama/llama3/generation.py
Normal file
|
@ -0,0 +1,367 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, LLMInput
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
device: str = "cuda",
|
||||
):
|
||||
device = torch.device(device)
|
||||
if (
|
||||
device.type == "cuda"
|
||||
and not torch.cuda.is_available()
|
||||
or device.type == "xpu"
|
||||
and not torch.xpu.is_available()
|
||||
):
|
||||
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
if device.type == "cuda":
|
||||
torch.distributed.init_process_group("nccl")
|
||||
else:
|
||||
torch.distributed.init_process_group("gloo")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif device.type == "xpu":
|
||||
torch.xpu.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
)
|
||||
|
||||
assert model_args.vocab_size == tokenizer.n_words
|
||||
|
||||
def build_model():
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
return model
|
||||
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
|
||||
torch.set_default_device(device)
|
||||
else:
|
||||
print(f"Setting default device to {device}")
|
||||
torch.set_default_device(device)
|
||||
if device.type == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
elif device.type == "xpu":
|
||||
if torch.xpu.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
print("Done...")
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama3(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
params = self.model.params
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
for inp in model_inputs:
|
||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [inp.tokens for inp in model_inputs]
|
||||
|
||||
bsz = len(model_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
is_vision = not isinstance(self.model, Transformer)
|
||||
if is_vision:
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
|
||||
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=images,
|
||||
batch_masks=mask,
|
||||
total_len=total_len,
|
||||
device=tokens.device,
|
||||
)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz)
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
text_only_inference = all(inp.vision is None for inp in model_inputs)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
text_only_inference,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
|
@ -16,7 +16,7 @@ from typing import List, Optional
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
|
|
305
llama_stack/models/llama/llama3/model.py
Normal file
305
llama_stack/models/llama/llama3/model.py
Normal file
|
@ -0,0 +1,305 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
|
||||
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
|
||||
# dependencies. These dependencies are not part of the default dependencies
|
||||
# (requirements.txt) of the `llama-models` package.
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
wavelen = 2 * torch.pi / freqs
|
||||
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
return torch.where(
|
||||
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
|
||||
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
|
||||
new_freqs,
|
||||
)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
|
||||
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if mask is not None:
|
||||
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=4 * args.dim,
|
||||
multiple_of=args.multiple_of,
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, params: ModelArgs):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.vocab_size = params.vocab_size
|
||||
self.n_layers = params.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(params.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, params))
|
||||
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
params.dim // params.n_heads,
|
||||
params.max_seq_len * 2,
|
||||
params.rope_theta,
|
||||
params.use_scaled_rope,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
mask = None
|
||||
if seqlen > 1:
|
||||
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if mask.device.type == torch.device("mps").type:
|
||||
mask = torch.nan_to_num(mask, nan=0.0)
|
||||
|
||||
# When performing key-value caching, we compute the attention scores
|
||||
# only for the new sequence. Thus, the matrix of scores is of size
|
||||
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
return output
|
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||
"""
|
||||
Resize position embedding for vision encoder.
|
||||
Original position embedding is [n_tiles * n_tiles + 1, dim]
|
||||
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
|
||||
"""
|
||||
new_grid_size = to_2tuple(grid_size)
|
||||
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
|
||||
|
||||
new_pos_emb_tok, new_pos_emb_img = (
|
||||
orig_pos_embed[:1],
|
||||
orig_pos_embed[1:],
|
||||
)
|
||||
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
|
||||
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
|
||||
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
|
||||
return new_pos_embed
|
||||
|
||||
|
||||
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a local position embedding for vision encoder and uses it
|
||||
to initialize the global position embedding.
|
||||
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
pos_embed = pos_and_cls_embed[1:]
|
||||
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
|
||||
grid_size = to_2tuple(grid_size)
|
||||
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
|
||||
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
|
||||
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
|
||||
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a global position embedding for vision encoder and resizes it to new size.
|
||||
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
# first remove cls token
|
||||
pos_embed = pos_and_cls_embed[:, :, 1:]
|
||||
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
|
||||
|
||||
xs_old, ys_old, ntok, dim = pos_embed.shape
|
||||
old_grid_size = int(math.sqrt(ntok))
|
||||
|
||||
# move to correct form for interpolation
|
||||
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
|
||||
pos_embed = pos_embed.unsqueeze(0)
|
||||
|
||||
# interpolate
|
||||
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
|
||||
pos_embed = pos_embed.permute(0, 3, 1, 2)
|
||||
pos_embed_resized = F.interpolate(
|
||||
pos_embed,
|
||||
size=new_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
|
||||
|
||||
# move it back in place
|
||||
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
|
||||
|
||||
# interpolate cls token
|
||||
cls_embed = cls_embed.permute(2, 3, 0, 1)
|
||||
cls_embed_resized = F.interpolate(
|
||||
cls_embed,
|
||||
size=(x_scale, y_scale),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
|
||||
# add cls token back in
|
||||
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
|
||||
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def build_encoder_attention_mask(
|
||||
x: torch.Tensor,
|
||||
ar: torch.Tensor,
|
||||
ntok: int,
|
||||
num_chunks: int,
|
||||
n_heads: int,
|
||||
):
|
||||
"""
|
||||
Build vision encoder attention mask that omits padding tokens.
|
||||
"""
|
||||
masks = []
|
||||
for arx in ar:
|
||||
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
|
||||
mask_i[: arx[0] * arx[1], :ntok] = 0
|
||||
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
|
||||
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
|
||||
mask_i = mask_i.unsqueeze(0)
|
||||
masks.append(mask_i)
|
||||
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
|
||||
return masks
|
||||
|
||||
|
||||
def expand_num_tokens_to_mult8(x):
|
||||
num_pad_tokens = 8 - (x.shape[-2] % 8)
|
||||
if num_pad_tokens == 0:
|
||||
return x, 0
|
||||
else:
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
x,
|
||||
torch.zeros(
|
||||
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
),
|
||||
],
|
||||
dim=-2,
|
||||
),
|
||||
num_pad_tokens,
|
||||
)
|
||||
|
||||
|
||||
def contract_num_tokens_from_mult8(x, num_pad_tokens):
|
||||
if num_pad_tokens == 0:
|
||||
return x
|
||||
return x[:, :, :-num_pad_tokens]
|
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
|
@ -0,0 +1,408 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from typing import Any, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
IMAGE_RES = 224
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, depth in value:
|
||||
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas the allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
1428
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
1428
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
File diff suppressed because it is too large
Load diff
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_negative_inf_value(dtype):
|
||||
return torch.finfo(dtype).min
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
|
@ -15,7 +15,7 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
|
|
5
llama_stack/models/llama/llama3/quantization/__init__.py
Normal file
5
llama_stack/models/llama/llama3/quantization/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
316
llama_stack/models/llama/llama3/quantization/loader.py
Normal file
316
llama_stack/models/llama/llama3/quantization/loader.py
Normal file
|
@ -0,0 +1,316 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# type: ignore
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import Tensor, nn
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
ffn_swiglu,
|
||||
load_fp8,
|
||||
quantize_fp8,
|
||||
)
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..multimodal.model import CrossAttentionTransformer
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
|
||||
elif quantization_mode == QuantizationMode.int4_mixed:
|
||||
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
|
||||
|
||||
|
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer:
|
||||
# Move weights to GPU with quantization
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
print("Loading fp8 scales...")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = load_fp8(
|
||||
param.weight,
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
|
||||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
print("Quantizing fp8 weights from bf16...")
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=device,
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device=device)
|
||||
return model
|
||||
|
||||
|
||||
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
||||
"""
|
||||
Int8DynActInt4WeightLinear with LoRA adaptor.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
device: Device to use.
|
||||
group_size: Group size for quantization.
|
||||
precision: Precision of quantization.
|
||||
scales_precision: Precision of scales.
|
||||
lora_rank: Rank of LoRA adaptor.
|
||||
lora_scale: Scale of LoRA adaptor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias=False,
|
||||
device=None,
|
||||
# quantization parameters
|
||||
group_size: int = 256,
|
||||
precision: torch.dtype = torch.float32,
|
||||
scales_precision: torch.dtype = torch.float32,
|
||||
# LoRA parameters
|
||||
lora_rank: Optional[int] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
groupsize=group_size,
|
||||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
self.lora_scale: Optional[float] = None
|
||||
self.adaptor: Optional[nn.Sequential] = None
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
self.adaptor = nn.Sequential()
|
||||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized weights from the state dict."""
|
||||
if prefix + "zeros" not in state_dict:
|
||||
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
|
||||
assert prefix + "scales" in state_dict
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
|
||||
|
||||
def forward(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
module_out = super().forward(input_)
|
||||
if self.adaptor is not None:
|
||||
adaptor_out = self.adaptor(input_) * self.lora_scale
|
||||
return module_out + adaptor_out
|
||||
return module_out
|
||||
|
||||
|
||||
class Int8WeightEmbedding(torch.nn.Embedding):
|
||||
"""An embedding layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
num_embeddings: Number of embeddings.
|
||||
embedding_dim: Embedding dimension.
|
||||
padding_idx: Padding index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int,
|
||||
device=None,
|
||||
) -> None:
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized embedding weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
class Int8WeightLinear(torch.nn.Linear):
|
||||
"""A linear layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
|
||||
super().__init__(in_features, out_features, bias, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized linear weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
def _prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model: torch.nn.Module,
|
||||
group_size: int,
|
||||
lora_rank: Optional[int],
|
||||
lora_scale: Optional[float],
|
||||
):
|
||||
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
|
||||
|
||||
Note that the weights of embedding and output layers are quantized to int8.
|
||||
"""
|
||||
device = None
|
||||
for module_name, module in model.named_children():
|
||||
if module_name == "output":
|
||||
quantized_module = Int8WeightLinear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif module_name == "tok_embeddings":
|
||||
quantized_module = Int8WeightEmbedding(
|
||||
num_embeddings=module.num_embeddings,
|
||||
embedding_dim=module.embedding_dim,
|
||||
padding_idx=module.padding_idx,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
|
||||
quantized_module = Int8DynActInt4WeightLinearLoRA(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=False,
|
||||
group_size=group_size,
|
||||
lora_rank=lora_rank,
|
||||
lora_scale=lora_scale,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
else:
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
model_args = model.params
|
||||
assert model_args.quantization_args is not None, "Quantization args must be specified."
|
||||
quantization_args = model_args.quantization_args
|
||||
if quantization_args.scheme is None:
|
||||
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
|
||||
|
||||
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
|
||||
)
|
||||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
|
||||
|
||||
if model_args.lora_args is None:
|
||||
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
|
||||
lora_rank = None
|
||||
lora_scale = None
|
||||
else:
|
||||
lora_rank = model_args.lora_args.rank
|
||||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
|
||||
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
|
|
@ -12,8 +12,7 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
|
||||
from ..datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
|
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
|
|
@ -16,7 +16,8 @@ import re
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
|
|
@ -4,12 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
|
|
95
llama_stack/models/llama/llama4/args.py
Normal file
95
llama_stack/models/llama/llama4/args.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
class QuantizationArgs(BaseModel):
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
|
||||
class LoRAArgs(BaseModel):
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
class MoEArgs(BaseModel):
|
||||
num_experts: int = -1
|
||||
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
|
||||
auto_scale_F: bool = ( # noqa: N815
|
||||
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
|
||||
)
|
||||
top_k: int = 1
|
||||
interleave_moe_layer_step: int = 1
|
||||
|
||||
|
||||
class Size(BaseModel):
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class VisionArgs(BaseModel):
|
||||
image_size: Size
|
||||
patch_size: Size
|
||||
|
||||
# parameters for the encoder transformer
|
||||
dim: int
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
mlp_ratio: float
|
||||
output_dim: int
|
||||
|
||||
pixel_shuffle_ratio: float
|
||||
|
||||
|
||||
class ModelArgs(BaseModel):
|
||||
dim: int = -1
|
||||
n_layers: int = -1
|
||||
n_heads: int = -1
|
||||
n_kv_heads: Optional[int] = None
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
ffn_exp: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
attention_chunk_size: Optional[int] = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
attn_temperature_tuning: bool = False
|
||||
floor_scale: float = 8192.0
|
||||
attn_scale: float = 0.1
|
||||
|
||||
vision_args: Optional[VisionArgs] = None
|
||||
moe_args: Optional[MoEArgs] = None
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate(self) -> "ModelArgs":
|
||||
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
|
||||
assert self.n_heads % self.n_kv_heads == 0, (
|
||||
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
|
||||
return self
|
|
@ -13,7 +13,7 @@ import torch
|
|||
from PIL import Image as PIL_Image
|
||||
|
||||
# TODO: either fork these or move them to the common package
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
||||
LLMInput,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
|
||||
ResizeNormalizeImageTransform,
|
||||
VariableSizeImageTransform,
|
||||
)
|
||||
|
||||
from ..llama3.tool_utils import ToolUtils
|
||||
from .args import VisionArgs
|
||||
from .datatypes import LLMInput
|
||||
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
|
@ -54,7 +48,7 @@ class TransformedImage:
|
|||
aspect_ratio: Tuple[int, int]
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
if image.mode == "RGBA":
|
||||
image.load() # for png.split()
|
||||
new_img = PIL_Image.new("RGB", image.size, bg)
|
||||
|
@ -171,7 +165,7 @@ class ChatFormat:
|
|||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = convert_rgba_to_rgb(image)
|
||||
image = convert_image_to_rgb(image)
|
||||
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
|
||||
|
||||
if image_tiles.shape[0] > 1:
|
||||
|
|
57
llama_stack/models/llama/llama4/datatypes.py
Normal file
57
llama_stack/models/llama/llama4/datatypes.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskedEmbedding:
|
||||
embedding: torch.Tensor
|
||||
mask: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
"""
|
||||
This is the input to the LLM from the "user" -- the user in this case views the
|
||||
Llama4 model holistically and does not care or know about its inner workings (e.g.,
|
||||
whether it has an encoder or if it is early fusion or not.)
|
||||
|
||||
This is distinct from the "TransformerInput" class which is really the Llama4
|
||||
backbone operating on early fused modalities and producing text output
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# images are already pre-processed (resized, tiled, etc.)
|
||||
images: Optional[List[torch.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerInput:
|
||||
"""
|
||||
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
|
||||
are expected to be "embedded" via encoders sitting before this layer in the model.
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# tokens_position defines the position of the tokens in each batch,
|
||||
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
|
||||
# - when it is an int, the start position are the same for all batches
|
||||
tokens_position: Union[torch.Tensor, int]
|
||||
image_embedding: Optional[MaskedEmbedding] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMOutput:
|
||||
logits: torch.Tensor
|
||||
|
||||
|
||||
TransformerOutput = LLMOutput
|
58
llama_stack/models/llama/llama4/ffn.py
Normal file
58
llama_stack/models/llama/llama4/ffn.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
do_reduce: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.do_reduce = do_reduce
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "mlp.fc1_weight" in state_dict:
|
||||
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
|
||||
state_dict[prefix + "w1.weight"] = w1
|
||||
state_dict[prefix + "w3.weight"] = w3
|
||||
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
|
||||
|
||||
def forward(self, x):
|
||||
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
|
||||
out = F.linear(x, self.w2.weight)
|
||||
if self.do_reduce:
|
||||
return reduce_from_model_parallel_region(out)
|
||||
return out
|
313
llama_stack/models/llama/llama4/generation.py
Normal file
313
llama_stack/models/llama/llama4/generation.py
Normal file
|
@ -0,0 +1,313 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, RawContent, RawMessage
|
||||
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
|
||||
from .model import Transformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
|
||||
|
||||
|
||||
class Llama4:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
**params,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
# TODO: params.json should always have correct vocab_size
|
||||
if model_args.vocab_size == -1:
|
||||
model_args.vocab_size = tokenizer.n_words
|
||||
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
|
||||
print("Model args:\n", model_args.model_dump_json(indent=2))
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
moe_num_experts=model_args.moe_args.num_experts,
|
||||
)
|
||||
print("Loaded checkpoint")
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama4(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
params = self.model.args
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
cprint("Input to model:\n", "yellow")
|
||||
for inp in llm_inputs:
|
||||
cprint(self.tokenizer.decode(inp.tokens), "grey")
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
image_embedding = None
|
||||
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
|
||||
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
|
||||
image_mask = image_mask.unsqueeze(-1)
|
||||
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
|
||||
|
||||
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
|
||||
image_embedding = MaskedEmbedding(
|
||||
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
|
||||
mask=image_mask,
|
||||
)
|
||||
|
||||
xformer_input = TransformerInput(
|
||||
tokens=tokens[:, prev_pos:cur_pos],
|
||||
tokens_position=prev_pos,
|
||||
image_embedding=image_embedding,
|
||||
)
|
||||
xformer_output = self.model.forward(xformer_input)
|
||||
logits = xformer_output.logits
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
442
llama_stack/models/llama/llama4/model.py
Normal file
442
llama_stack/models/llama/llama4/model.py
Normal file
|
@ -0,0 +1,442 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
from .datatypes import TransformerInput, TransformerOutput
|
||||
from .ffn import FeedForward
|
||||
from .moe import MoE
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
class L2Norm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor):
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scale_factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
# TODO: this module needs to be moved into a separate file since it can be used by
|
||||
# the vision encoder as well.
|
||||
def __init__(
|
||||
self,
|
||||
args: ModelArgs,
|
||||
use_qk_norm: bool,
|
||||
use_rope: bool,
|
||||
add_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
self.use_qk_norm = use_qk_norm
|
||||
# For attention temperature tuning
|
||||
self.attn_temperature_tuning = args.attn_temperature_tuning
|
||||
self.floor_scale = args.floor_scale
|
||||
self.attn_scale = args.attn_scale
|
||||
|
||||
self.n_heads = args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=add_bias,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "wqkv.weight" in state_dict:
|
||||
wqkv = state_dict.pop(prefix + "wqkv.weight")
|
||||
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
|
||||
if r != 0:
|
||||
raise ValueError(
|
||||
f"shape={tuple(wqkv.shape)} is not divisible by "
|
||||
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
|
||||
state_dict[prefix + "wq.weight"] = wq
|
||||
state_dict[prefix + "wk.weight"] = wk
|
||||
state_dict[prefix + "wv.weight"] = wv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
if self.use_rope:
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = self.qk_norm(xq)
|
||||
xk = self.qk_norm(xk)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
# while working at very long context
|
||||
if self.attn_temperature_tuning and not self.use_rope:
|
||||
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
|
||||
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
|
||||
attn_scales = attn_scales.view(1, seqlen, 1, 1)
|
||||
xq = xq * attn_scales
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
xk = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
xv = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
|
||||
|
||||
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
||||
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
output = self.wo(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
|
||||
|
||||
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
|
||||
|
||||
use_rope = not self.is_nope_layer
|
||||
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
|
||||
|
||||
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
|
||||
|
||||
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
|
||||
self.feed_forward = MoE(
|
||||
dim=args.dim,
|
||||
hidden_dim=int(args.ffn_exp * args.dim),
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
multiple_of=args.multiple_of,
|
||||
moe_args=args.moe_args,
|
||||
)
|
||||
else:
|
||||
hidden_dim = int(4 * args.dim)
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if args.ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=hidden_dim,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
|
||||
|
||||
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
|
||||
elif prefix + "feed_forward.norm.weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
|
||||
|
||||
for k in (
|
||||
"feed_forward.experts.mlp",
|
||||
"feed_forward.mlp_shared",
|
||||
"attention.wo",
|
||||
"attention.wqkv",
|
||||
):
|
||||
if prefix + k + "._extra_state" in state_dict:
|
||||
state_dict.pop(prefix + k + "._extra_state")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
global_attn_mask: Optional[torch.Tensor],
|
||||
local_attn_mask: Optional[torch.Tensor],
|
||||
):
|
||||
# The iRoPE architecture uses global attention mask for NoPE layers or
|
||||
# if chunked local attention is not used
|
||||
if self.is_nope_layer or local_attn_mask is None:
|
||||
mask = global_attn_mask
|
||||
else:
|
||||
mask = local_attn_mask
|
||||
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, args))
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
args.dim // args.n_heads,
|
||||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
# circular import otherwise until we refactor out Attention
|
||||
from .vision.embedding import VisionEmbeddings
|
||||
|
||||
self.vision_embeddings = VisionEmbeddings(vision_args)
|
||||
self.vision_projection = ColumnParallelLinear(
|
||||
vision_args.output_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "rope.freqs" in state_dict:
|
||||
state_dict.pop(prefix + "rope.freqs")
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, model_input: TransformerInput) -> TransformerOutput:
|
||||
tokens = model_input.tokens
|
||||
start_pos = model_input.tokens_position
|
||||
assert isinstance(start_pos, int), (
|
||||
"This implementation does not support different start positions per batch item"
|
||||
)
|
||||
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
if image_embedding := model_input.image_embedding:
|
||||
h_image = self.vision_projection(image_embedding.embedding)
|
||||
h = h * ~image_embedding.mask + h_image * image_embedding.mask
|
||||
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
global_attn_mask, local_attn_mask = None, None
|
||||
if seqlen > 1:
|
||||
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if global_attn_mask.device.type == torch.device("mps").type:
|
||||
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
|
||||
|
||||
if chunk_size := self.args.attention_chunk_size:
|
||||
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
|
||||
return TransformerOutput(logits=output)
|
||||
|
||||
|
||||
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
|
||||
# in the iRoPE architecture
|
||||
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
|
||||
block_pos = torch.abs(
|
||||
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
|
||||
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
|
||||
)
|
||||
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
|
||||
mask = (block_pos == 0) & (token_pos <= 0)
|
||||
return mask.to(device)
|
214
llama_stack/models/llama/llama4/moe.py
Normal file
214
llama_stack/models/llama/llama4/moe.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# ruff: noqa: N806
|
||||
# pyre-strict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .args import MoEArgs
|
||||
from .ffn import FeedForward
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_local_experts: int,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.num_local_experts = num_local_experts
|
||||
self.dim = dim
|
||||
divide_factor = fs_init.get_model_parallel_world_size()
|
||||
|
||||
self.w1: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w2: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w3: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
if prefix + "moe_w_in_eD_F" in state_dict:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
|
||||
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
|
||||
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
routed_in_egD: torch.Tensor, # noqa: N803
|
||||
) -> torch.Tensor:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
|
||||
x_egD = routed_in_egD.view(e, -1, D)
|
||||
|
||||
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
|
||||
out_egD = out_egD.view(-1, D)
|
||||
|
||||
return out_egD
|
||||
|
||||
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
return torch.bmm(middle_out_egF, w2)
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGD: [e, G, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
ffn_dim_multiplier: float,
|
||||
multiple_of: int,
|
||||
moe_args: MoEArgs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.moe_args = moe_args
|
||||
|
||||
hidden_dim_denom: float = 1
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim_denom = moe_args.capacity_factor + 1
|
||||
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
|
||||
# custom dim factor multiplier
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim = int(hidden_dim / hidden_dim_denom)
|
||||
|
||||
hidden_dim += -hidden_dim % multiple_of
|
||||
|
||||
num_local_experts: int = moe_args.num_experts
|
||||
dtype: torch.dtype = torch.get_default_dtype()
|
||||
self.experts = Experts(
|
||||
num_local_experts,
|
||||
dim,
|
||||
hidden_dim,
|
||||
)
|
||||
|
||||
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
|
||||
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "w_in_shared_FD.weight" in state_dict:
|
||||
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
|
||||
|
||||
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
|
||||
_, slen, D = x_bsD.shape
|
||||
x_aD = x_bsD.view(-1, D)
|
||||
|
||||
a = x_aD.shape[0]
|
||||
|
||||
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
|
||||
|
||||
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
|
||||
router_scores = (
|
||||
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
|
||||
.scatter_(1, router_indices_aK, router_scores_aK)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
|
||||
|
||||
router_scores = torch.sigmoid(router_scores)
|
||||
|
||||
routed_in_EG_D: Tensor = torch.gather(
|
||||
x_aD,
|
||||
dim=0,
|
||||
index=router_indices.reshape(-1, 1).expand(-1, D),
|
||||
)
|
||||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_eg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
|
||||
|
||||
|
||||
def divide_exact(numerator: int, denominator: int) -> int:
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
436
llama_stack/models/llama/llama4/preprocess.py
Normal file
436
llama_stack/models/llama/llama4/preprocess.py
Normal file
|
@ -0,0 +1,436 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image, ImageFile
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
IMAGE_RES = 448
|
||||
|
||||
|
||||
class ResizeNormalizeImageTransform:
|
||||
def __init__(
|
||||
self,
|
||||
size_width=None,
|
||||
size_height=None,
|
||||
) -> None:
|
||||
self._size_width = size_width or IMAGE_RES
|
||||
self._size_height = size_height or IMAGE_RES
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
|
||||
self.tv_transform = tv.Compose(
|
||||
[
|
||||
tv.Resize((self._size_height, self._size_width)),
|
||||
tv.ToTensor(),
|
||||
tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, image: Image.Image) -> torch.Tensor:
|
||||
return self.tv_transform(image)
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, width in value:
|
||||
possible_resolutions.append((height * patch_size, width * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(
|
||||
max(new_size_without_distortion[1], 1),
|
||||
max(new_size_without_distortion[0], 1),
|
||||
),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas the allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
|
@ -4,20 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from llama_stack.models.llama.prompt_format import (
|
||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from ..prompt_format import (
|
||||
Llama4UseCase,
|
||||
TextCompletionContent,
|
||||
UseCase,
|
||||
|
|
5
llama_stack/models/llama/llama4/quantization/__init__.py
Normal file
5
llama_stack/models/llama/llama4/quantization/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
225
llama_stack/models/llama/llama4/quantization/loader.py
Normal file
225
llama_stack/models/llama/llama4/quantization/loader.py
Normal file
|
@ -0,0 +1,225 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from ...quantize_impls import ffn_swiglu
|
||||
|
||||
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
|
||||
|
||||
def experts_batched_swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor, # (e, g, D)
|
||||
w1: Tensor, # (e, D, F)
|
||||
w3: Tensor, # (e, D, F)
|
||||
w2: Tensor, # (e, F, D)
|
||||
) -> torch.Tensor:
|
||||
from ...quantize_impls import bmm_nt
|
||||
|
||||
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
|
||||
return bmm_nt(middle_out_egF, w2)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
use_rich_progress: bool = True,
|
||||
) -> Transformer:
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
Int4ScaledWeights,
|
||||
load_fp8,
|
||||
load_int4,
|
||||
quantize_fp8,
|
||||
quantize_int4,
|
||||
)
|
||||
|
||||
rank = get_model_parallel_rank()
|
||||
|
||||
def should_quantize_block(block: nn.Module) -> bool:
|
||||
if not isinstance(block, TransformerBlock):
|
||||
return False
|
||||
|
||||
is_moe = isinstance(block.feed_forward, MoE)
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
# skip quantization on first and last layers
|
||||
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
|
||||
return is_moe
|
||||
|
||||
use_rich_progress = use_rich_progress and rank == 0
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
|
||||
if os.path.isfile(int4_scales_path):
|
||||
log_status(f"Rank {rank}: Loading int4 scales")
|
||||
int4_scales = torch.load(int4_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = int4_scales[key]
|
||||
return load_int4(
|
||||
weight,
|
||||
scale,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
log_status(f"Rank {rank}: Loading fp8 scales")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = fp8_scales[key]
|
||||
return load_fp8(
|
||||
weight,
|
||||
scale,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
processed_blocks = 0
|
||||
try:
|
||||
if use_rich_progress:
|
||||
progress.start()
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if not should_quantize_block(block):
|
||||
continue
|
||||
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(
|
||||
f"{prefix}.experts.{key}",
|
||||
param.transpose(1, 2).contiguous(),
|
||||
),
|
||||
)
|
||||
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
# Quantize shared experts
|
||||
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.shared_expert, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
|
||||
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
|
||||
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
|
||||
update_status(f"Rank {rank} - Moving parameters to CUDA")
|
||||
|
||||
param_count = 0
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
param_count += 1
|
||||
|
||||
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
|
||||
finally:
|
||||
if use_rich_progress:
|
||||
progress.stop()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
|
||||
def logging_callbacks(
|
||||
use_rich_progress: bool,
|
||||
rank: int,
|
||||
model: Transformer,
|
||||
should_quantize_block: Callable[[nn.Module], bool],
|
||||
):
|
||||
console = None
|
||||
if use_rich_progress:
|
||||
from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
|
||||
def log_status(message: str) -> None:
|
||||
if use_rich_progress:
|
||||
console.print(message)
|
||||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
|
||||
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
|
||||
progress = None
|
||||
if use_rich_progress:
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
BarColumn(complete_style="green", finished_style="bright_green"),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("ETA:"),
|
||||
TimeRemainingColumn(),
|
||||
TextColumn("[bold]{task.fields[status]}"),
|
||||
console=console,
|
||||
expand=True,
|
||||
)
|
||||
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
|
||||
|
||||
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
|
||||
if use_rich_progress:
|
||||
if message is not None:
|
||||
progress.update(task_id, status=message)
|
||||
if completed is not None:
|
||||
progress.update(task_id, completed=completed)
|
||||
elif rank == 0 and completed and completed % 10 == 0:
|
||||
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
|
||||
return progress, log_status, update_status
|
|
@ -4,9 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
@ -59,8 +56,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
|
|||
"<|text_post_train_reserved_special_token_3|>",
|
||||
"<|text_post_train_reserved_special_token_4|>",
|
||||
"<|text_post_train_reserved_special_token_5|>",
|
||||
"<|python_start|>",
|
||||
"<|python_end|>",
|
||||
"<|finetune_right_pad|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"text_post_train", 61, 6
|
||||
|
@ -85,8 +80,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
|
|||
"vision", 1041, 7
|
||||
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
|
||||
|
||||
# 201134, ..., 201143
|
||||
LLAMA4_REASONING_SPECIAL_TOKENS = [
|
||||
"<|reasoning_reserved_special_token_0|>",
|
||||
"<|reasoning_reserved_special_token_1|>",
|
||||
"<|reasoning_reserved_special_token_2|>",
|
||||
"<|reasoning_reserved_special_token_3|>",
|
||||
"<|reasoning_reserved_special_token_4|>",
|
||||
"<|reasoning_reserved_special_token_5|>",
|
||||
"<|reasoning_reserved_special_token_6|>",
|
||||
"<|reasoning_reserved_special_token_7|>",
|
||||
"<|reasoning_thinking_start|>",
|
||||
"<|reasoning_thinking_end|>",
|
||||
]
|
||||
|
||||
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
|
||||
LLAMA4_SPECIAL_TOKENS = (
|
||||
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
|
||||
)
|
||||
|
||||
BASIC_SPECIAL_TOKENS = [
|
||||
"<|begin_of_text|>",
|
||||
|
@ -155,6 +165,9 @@ class Tokenizer:
|
|||
self.eot_id: int = self.special_tokens["<|eot|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom|>"]
|
||||
|
||||
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
|
||||
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
|
||||
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom|>"],
|
||||
|
|
209
llama_stack/models/llama/llama4/vision/embedding.py
Normal file
209
llama_stack/models/llama/llama4/vision/embedding.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
|
||||
from ..args import VisionArgs
|
||||
from .encoder import VisionEncoder
|
||||
|
||||
|
||||
class PixelShuffle(nn.Module):
|
||||
def __init__(self, ps_ratio):
|
||||
super().__init__()
|
||||
self.ps_ratio = ps_ratio
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, N, C], N = number of patches
|
||||
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
|
||||
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
|
||||
hh = ww = int(math.sqrt(x.shape[1]))
|
||||
x = x.reshape(x.shape[0], hh, ww, -1)
|
||||
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
|
||||
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
return pixel_shuffle_patches
|
||||
|
||||
|
||||
def pixel_shuffle_op(input_x, ps_ratio):
|
||||
n, w, h, c = input_x.size()
|
||||
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
input_x = input_x.view(
|
||||
n,
|
||||
int(h * ps_ratio),
|
||||
int(w * ps_ratio),
|
||||
int(c / (ps_ratio * ps_ratio)),
|
||||
)
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
return input_x
|
||||
|
||||
|
||||
class SimpleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
bias: bool = True,
|
||||
dropout: float = 0.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.non_linearity(self.c_proj(hidden))
|
||||
|
||||
|
||||
class PixelShuffleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ps_ratio: float,
|
||||
input_dim: int,
|
||||
output_dim: int = 4096,
|
||||
add_fc: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle = PixelShuffle(ps_ratio)
|
||||
self.mlp = SimpleMLP(
|
||||
int(input_dim // (ps_ratio**2)),
|
||||
output_dim,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
self.fc = nn.Identity()
|
||||
if add_fc:
|
||||
self.fc = ColumnParallelLinear(
|
||||
output_dim,
|
||||
output_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = self.pixel_shuffle(encoded_patches)
|
||||
return self.fc(self.mlp(encoded_patches))
|
||||
|
||||
|
||||
class VisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, args: VisionArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
image_size = args.image_size
|
||||
patch_size = args.patch_size
|
||||
self.vision_encoder = VisionEncoder(
|
||||
image_size=(image_size.height, image_size.width),
|
||||
patch_size=(patch_size.height, patch_size.width),
|
||||
dim=args.dim,
|
||||
layers=args.n_layers,
|
||||
heads=args.n_heads,
|
||||
mlp_ratio=args.mlp_ratio,
|
||||
)
|
||||
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
|
||||
self.vision_adapter = PixelShuffleMLP(
|
||||
ps_ratio=args.pixel_shuffle_ratio,
|
||||
input_dim=args.dim,
|
||||
output_dim=args.output_dim,
|
||||
)
|
||||
|
||||
self.output_dim = args.output_dim
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
original_sd = self.state_dict()
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
|
||||
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
|
||||
|
||||
def _get_empty_sequence(self, h):
|
||||
return torch.zeros(
|
||||
h.shape[0],
|
||||
h.shape[1],
|
||||
self.output_dim,
|
||||
device=h.device,
|
||||
dtype=h.dtype,
|
||||
)
|
||||
|
||||
# x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
|
||||
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||
def forward(
|
||||
self,
|
||||
image_batch: List[List[torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
h_ref: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
images_flattened = [image for sample in image_batch for image in sample]
|
||||
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
|
||||
embedding = self.vision_encoder(images_flattened)
|
||||
projected_embedding = self.vision_adapter(embedding)
|
||||
|
||||
h_image = self._get_empty_sequence(h_ref)
|
||||
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
|
||||
|
||||
|
||||
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
|
||||
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
|
||||
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
|
||||
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
|
||||
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
|
||||
|
||||
assert not torch.isnan(encoded_patches_proj).any()
|
||||
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
|
||||
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
|
||||
)
|
||||
|
||||
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
|
||||
for index in range(h_image.size(0)):
|
||||
encoded_patches_per_sample = encoded_patches_list[index]
|
||||
sample_image_mask = image_mask[index]
|
||||
|
||||
if encoded_patches_per_sample.numel() == 0:
|
||||
continue
|
||||
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
|
||||
-1, encoded_patches_per_sample.size(-1)
|
||||
)
|
||||
|
||||
n_tokens_to_fill = sample_image_mask.sum()
|
||||
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
|
||||
|
||||
h_image[index].masked_scatter_(
|
||||
sample_image_mask.expand(-1, h_image.size(-1)),
|
||||
encoded_patches_per_sample[:n_tokens_to_fill],
|
||||
)
|
||||
|
||||
return h_image
|
411
llama_stack/models/llama/llama4/vision/encoder.py
Normal file
411
llama_stack/models/llama/llama4/vision/encoder.py
Normal file
|
@ -0,0 +1,411 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import einsum
|
||||
|
||||
from ..args import ModelArgs
|
||||
from ..model import Attention
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||
"""Conv2D Patching layer with model parallelism.
|
||||
Column parallel over unfolded input.
|
||||
Arguments:
|
||||
in_channels: Input channels.
|
||||
out_channels: Output channels.
|
||||
kernel_size: Size of convolution kernel.
|
||||
stride (default 1): Stride for convolution.
|
||||
bias (default False): Use bias in Conv2d.
|
||||
Input: (bsz, in_channels, height, width)
|
||||
Output: (bsz, num_tokens, out_channels)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]],
|
||||
bias: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
|
||||
self._linear = ColumnParallelLinear(
|
||||
in_channels * kernel_size[0] * kernel_size[1],
|
||||
out_channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self._unfold(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self._linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class _FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
dropout: float,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.c_proj(hidden)
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert d_model % n_head == 0
|
||||
self.n_heads = n_head
|
||||
self.head_dim = d_model // self.n_heads
|
||||
|
||||
attn_args = ModelArgs(
|
||||
dim=d_model,
|
||||
head_dim=self.head_dim,
|
||||
n_heads=self.n_heads,
|
||||
n_kv_heads=self.n_heads,
|
||||
)
|
||||
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = _FeedForward(
|
||||
dim=d_model,
|
||||
hidden_dim=int(mlp_ratio * d_model),
|
||||
dropout=0.0,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.gated = gated
|
||||
if gated:
|
||||
self.gate_attn = nn.Parameter(torch.zeros(1))
|
||||
self.gate_ffn = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||
|
||||
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
|
||||
x = x + _gate_ffn * self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class _Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
_TransformerBlock(
|
||||
d_model=dim,
|
||||
n_head=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
gated=gated,
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
|
||||
out = []
|
||||
for idx, r in enumerate(self.resblocks):
|
||||
if return_intermediate is not None and idx in return_intermediate:
|
||||
out.append(x)
|
||||
x = r(x, mask=mask, freq_cis=freq_cis)
|
||||
if return_intermediate is not None:
|
||||
return x, torch.stack(out, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
class PackingIndex:
|
||||
Z = 0 # Z (time) coordinate of the token in the original sample
|
||||
Y = 1 # Y (height) coordinate of the token in the original sample
|
||||
X = 2 # X (width) coordinate of the token in the original sample
|
||||
TIME = 3 # Total number of time units (frames) in the original sample
|
||||
HEIGHT = 4 # Height of the original sample
|
||||
WIDTH = 5 # Width of the original sample
|
||||
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
|
||||
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
|
||||
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
|
||||
|
||||
# Total size of the enum, remember to update this!
|
||||
NUM_METADATA = 8
|
||||
|
||||
# Note: For padding tokens IDX = -1
|
||||
# For cls tokens, IDX = -2
|
||||
ID_CLS_TOKEN = -2
|
||||
ID_PAD_TOKEN = -1
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
patch_size: Tuple[int, int],
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
in_channels: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (
|
||||
self.image_size[0] // self.patch_size[0],
|
||||
self.image_size[1] // self.patch_size[1],
|
||||
)
|
||||
self.conv1 = ColumnParallelConv2dPatch(
|
||||
in_channels=in_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
)
|
||||
scale = dim**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
|
||||
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
|
||||
)
|
||||
|
||||
self.ln_pre = LayerNorm(dim)
|
||||
self.ln_post = LayerNorm(dim)
|
||||
self.transformer = _Transformer(
|
||||
dim,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
|
||||
# NOTE: hack for the fixed res
|
||||
image_h, image_w = self.image_size
|
||||
patch_h, patch_w = self.patch_size
|
||||
idx_h, idx_w = image_h // patch_h, image_w // patch_w
|
||||
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
|
||||
img_idx = img_idx.reshape(idx_h * idx_w, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
|
||||
|
||||
packed_img_idx = torch.empty(
|
||||
img_idx.shape[0],
|
||||
img_idx.shape[1],
|
||||
PackingIndex.NUM_METADATA - 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
|
||||
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
|
||||
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
|
||||
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
|
||||
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
|
||||
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
|
||||
self.packed_img_idx = packed_img_idx # for positional embedding load hook
|
||||
|
||||
# compute rope freqs
|
||||
rope_freq = self.get_rope_freqs(dim // heads // 2)
|
||||
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
|
||||
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
|
||||
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||
# disable RoPE for padding and cls tokens
|
||||
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
|
||||
# compute complex freqs
|
||||
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
||||
# xlf automatically broadcasts
|
||||
self.freq_cis = self.freq_cis.squeeze(0)
|
||||
self.n_heads = heads // fs_init.get_model_parallel_world_size()
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def get_rope_freqs(self, dim, theta=10000):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
return freqs
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def compute_rope_freqs(self, freqs, t):
|
||||
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
||||
freqs = freqs.repeat_interleave(2, dim=-1)
|
||||
return freqs
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
|
||||
raise ValueError(
|
||||
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
|
||||
)
|
||||
|
||||
batch_size, token_per_image, _ = self.packed_img_idx.shape
|
||||
# Input points for idx are [x, y, w, h]
|
||||
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
|
||||
total_windows, window_size, _ = idx.shape
|
||||
|
||||
# Grid values are [-1, 1] and coords are w, h
|
||||
grid = (
|
||||
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
|
||||
)[None, ...]
|
||||
|
||||
# In this mode, cls token has no position embedding
|
||||
if orig_pos_embed is not None:
|
||||
posemb = (
|
||||
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
|
||||
sample = F.grid_sample(
|
||||
posemb, grid, padding_mode="zeros"
|
||||
) # padding tokens / class token will get zero for posemb
|
||||
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
|
||||
sample = torch.where(
|
||||
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
|
||||
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
|
||||
sample,
|
||||
)
|
||||
|
||||
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
|
||||
|
||||
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
|
||||
|
||||
if return_state_dict:
|
||||
return state_dict
|
||||
|
||||
def apply_class_embedding(self, x):
|
||||
x = torch.cat(
|
||||
[
|
||||
x,
|
||||
self.class_embedding.to(x.dtype)
|
||||
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
],
|
||||
dim=1,
|
||||
) # shape = [*, grid ** 2 + 1, width]
|
||||
return x
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
|
||||
if images.ndim == 5:
|
||||
num_concurrent_media = 1
|
||||
bsz, num_chunks, nch, h, w = images.shape
|
||||
else:
|
||||
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
|
||||
|
||||
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
# patch embedding
|
||||
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
x = self.conv1(x) # shape = [*, width, grid ** 2]
|
||||
_, ntok, dim = x.shape
|
||||
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
||||
|
||||
# apply cls token
|
||||
x = self.apply_class_embedding(x)
|
||||
ntok += 1
|
||||
|
||||
# apply position embeddings
|
||||
if self.positional_embedding_vlm is not None:
|
||||
x = x + self.positional_embedding_vlm.to(x.dtype)
|
||||
|
||||
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
||||
|
||||
x = self.ln_pre(x)
|
||||
x = x.view(bsz * num_concurrent_media, -1, dim)
|
||||
freq_cis = self.freq_cis.to(images.device)
|
||||
|
||||
tf_output = self.transformer(
|
||||
x,
|
||||
freq_cis=freq_cis,
|
||||
)
|
||||
|
||||
int_x = None
|
||||
if isinstance(tf_output, tuple):
|
||||
x, int_x = tf_output
|
||||
else:
|
||||
x = tf_output
|
||||
x = self.ln_post(x)
|
||||
|
||||
# remove cls token output
|
||||
x = x[:, :-1, :]
|
||||
|
||||
# add and output x + int_x features
|
||||
if int_x is not None:
|
||||
int_x = int_x[:, :-1, :, :]
|
||||
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
|
||||
x = torch.cat([x, int_x], dim=-1)
|
||||
|
||||
return x
|
|
@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
||||
LLMInput,
|
||||
)
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
@ -76,21 +73,22 @@ class UseCase(BaseModel):
|
|||
text += dialog
|
||||
text += "\n\n"
|
||||
continue
|
||||
|
||||
elif isinstance(dialog, TextCompletionContent):
|
||||
input_tokens, output_tokens = generator.text_completion_raw(
|
||||
dialog.content,
|
||||
temperature=0.1,
|
||||
top_p=0.95,
|
||||
max_gen_len=64,
|
||||
)
|
||||
else:
|
||||
input_tokens, output_tokens = generator.chat_completion_raw(
|
||||
dialog,
|
||||
temperature=0.0,
|
||||
top_p=0.95,
|
||||
max_gen_len=self.max_gen_len,
|
||||
batch = [dialog]
|
||||
method = (
|
||||
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
|
||||
)
|
||||
input_tokens = []
|
||||
output_tokens = []
|
||||
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
|
||||
result = token_results[0]
|
||||
if result.source == "input":
|
||||
input_tokens.append(result.token)
|
||||
else:
|
||||
output_tokens.append(result.token)
|
||||
|
||||
if result.finished:
|
||||
break
|
||||
text += "##### Input Prompt Format\n"
|
||||
|
||||
# FIXME: This is added to undo the hack in chat_formatter where
|
||||
|
@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):
|
|||
|
||||
text = ""
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
temperature = 0.0
|
||||
for dialog in self.dialogs:
|
||||
if isinstance(dialog, str):
|
||||
text += dialog
|
||||
text += "\n\n"
|
||||
continue
|
||||
|
||||
elif isinstance(dialog, TextCompletionContent):
|
||||
# TODO pass the raw input and do the encoding in the text completion function
|
||||
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
|
||||
llm_input = LLMInput(tokens=input_tokens)
|
||||
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
|
||||
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
|
||||
)
|
||||
|
||||
else:
|
||||
input_tokens, output_tokens = generator.chat_completion_raw(
|
||||
dialog,
|
||||
temperature=temperature,
|
||||
max_gen_len=self.max_gen_len,
|
||||
batch = [dialog]
|
||||
method = (
|
||||
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
|
||||
)
|
||||
input_tokens = []
|
||||
output_tokens = []
|
||||
for token_results in method(batch, echo=True, temperature=0.0):
|
||||
result = token_results[0]
|
||||
if result.source == "input":
|
||||
input_tokens.append(result.token)
|
||||
else:
|
||||
output_tokens.append(result.token)
|
||||
|
||||
if result.finished:
|
||||
break
|
||||
|
||||
text += "##### Input Prompt Format\n"
|
||||
text += _code_block(tokenizer.decode(input_tokens))
|
||||
|
|
332
llama_stack/models/llama/quantize_impls.py
Normal file
332
llama_stack/models/llama/quantize_impls.py
Normal file
|
@ -0,0 +1,332 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class Int4ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Int4Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Int4Weights(
|
||||
Int4ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Int4Weights",
|
||||
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def int4_row_quantize(
|
||||
x: torch.Tensor,
|
||||
group_size: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
n_bit = 4 # Number of target bits.
|
||||
to_quant = x.reshape(-1, group_size).to(torch.float)
|
||||
|
||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
||||
max_int = 2**n_bit - 1
|
||||
min_int = 0
|
||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
||||
|
||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
||||
|
||||
out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
|
||||
|
||||
# Recenter output and move to int8.
|
||||
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
|
||||
|
||||
# Cutlass expects column major layout for scale and zero point,
|
||||
# so we transpose here and make them contiguous.
|
||||
scales = scales.view(x.shape[0], -1).t().contiguous()
|
||||
zeros = zeros.view(x.shape[0], -1).t().contiguous()
|
||||
|
||||
return out, scales, zeros
|
||||
|
||||
|
||||
def pack_int4(x: torch.Tensor) -> torch.Tensor:
|
||||
# Given int8 x, pack adjacent int4 values into a single int8.
|
||||
low_x = x[:, ::2]
|
||||
high_x = x[:, 1::2]
|
||||
|
||||
# High bits need to left shift, this also masks off extra bits.
|
||||
high_x = torch.bitwise_left_shift(high_x, 4)
|
||||
# Low bits need to have sign bits removed.
|
||||
low_x = torch.bitwise_and(low_x, 0xF)
|
||||
|
||||
# Recombine into a single value with bitwise or.
|
||||
return torch.bitwise_or(low_x, high_x).contiguous()
|
||||
|
||||
|
||||
def bmm_nt(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
if isinstance(w, Fp8ScaledWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
||||
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
|
||||
elif isinstance(w, Int4ScaledWeights):
|
||||
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
|
||||
else:
|
||||
raise ValueError("Unsupported quantization type")
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
||||
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
if w.ndim >= 3:
|
||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||
scale = torch.stack(scale, dim=0)
|
||||
zero_point = torch.stack(zero_point, dim=0)
|
||||
else:
|
||||
wq, scale, zero_point = int4_row_quantize(w)
|
||||
wq = pack_int4(wq)
|
||||
del w
|
||||
return Int4Weights(
|
||||
weight=wq.to(output_device),
|
||||
scale=scale.to(output_device),
|
||||
zero_point=zero_point.to(output_device),
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
|
||||
scale=w_scale.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_int4(
|
||||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input INT4.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Int4Weights(
|
||||
weight=w.to(torch.int8).to(device=output_device),
|
||||
scale=scale.to(device=output_device),
|
||||
zero_point=zero_point.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_dynamic(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
|
||||
"""
|
||||
if isinstance(w, Int4Weights):
|
||||
y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
|
||||
else:
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_dynamic(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
assert x.dim() == 3 or x.dim() == 2
|
||||
if x.dim() == 3:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
else:
|
||||
(T, D) = x.shape # noqa: N806
|
||||
B = 1 # noqa: N806
|
||||
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
if x.dim() == 3:
|
||||
return z_.view(B, T, D)
|
||||
else:
|
||||
return z_
|
|
@ -4,24 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from .datatypes import (
|
||||
from .sku_types import (
|
||||
CheckpointQuantizationFormat,
|
||||
CoreModelId,
|
||||
Model,
|
||||
ModelFamily,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
|
||||
LLAMA2_VOCAB_SIZE = 32000
|
||||
|
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
|
|||
)
|
||||
|
||||
|
||||
def recommended_sampling_params() -> SamplingParams:
|
||||
return SamplingParams(
|
||||
strategy=TopPSamplingStrategy(
|
||||
temperature=1.0,
|
||||
top_p=0.9,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def llama2_family() -> List[Model]:
|
||||
return [
|
||||
*llama2_base_models(),
|
||||
|
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_7b,
|
||||
description="Llama 2 7b model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b,
|
||||
description="Llama 2 13b model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b,
|
||||
description="Llama 2 70b model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b,
|
||||
description="Llama 3 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b,
|
||||
description="Llama 3.1 8b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b,
|
||||
description="Llama 3.1 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
description="Llama 3.1 405b model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b,
|
||||
description="Llama 3.2 1b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b,
|
||||
description="Llama 3.2 3b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 3072,
|
||||
"n_layers": 28,
|
||||
|
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision,
|
||||
description="Llama 3.2 11b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision,
|
||||
description="Llama 3.2 90b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_7b_chat,
|
||||
description="Llama 2 7b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b_chat,
|
||||
description="Llama 2 13b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b_chat,
|
||||
description="Llama 2 70b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_8b_instruct,
|
||||
description="Llama 3 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b_instruct,
|
||||
description="Llama 3 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b_instruct,
|
||||
description="Llama 3.1 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b_instruct,
|
||||
description="Llama 3.1 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
description="Llama 3.1 405b instruct model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
||||
description="Llama 3.2 1b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_1b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b_instruct,
|
||||
description="Llama 3.2 3b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_3b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
|
||||
description="Llama 3.2 11b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
|
||||
description="Llama 3.2 90b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_3_70b_instruct,
|
||||
description="Llama 3.3 70b instruct",
|
||||
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_11b_vision,
|
||||
description="Llama Guard v3 11b vision system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
|
|||
description="Llama Guard v3 1b 'int4' quantized system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
|
||||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 12,
|
||||
|
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_1b,
|
||||
description="Llama Guard v3 1b system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
|
229
llama_stack/models/llama/sku_types.py
Normal file
229
llama_stack/models/llama/sku_types.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
llama4 = "llama4"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Llama 4 family
|
||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
]:
|
||||
return ModelFamily.llama4
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.core_model_id.value
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama4:
|
||||
if self.core_model_id in {
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
}:
|
||||
return 262144
|
||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
||||
return 10485760
|
||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
return 1048576
|
||||
|
||||
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
Loading…
Add table
Add a link
Reference in a new issue