llama-stack/llama_stack/models/llama/checkpoint.py
Ashwin Bharambe · 530d4bdfe1 · 2025-04-07 15:03:58 -07:00
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference-specific
tidbits into llama-models code even by accident.

It also kills the meta-reference-quantized-gpu distro and rolls the
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
    """Map a new MP rank to a list of old MP ranks given a change in MP size."""
    if new_mp_size % old_mp_size == 0:
        # Read old MP shard and split it into smaller ones
        return [new_mp_rank * old_mp_size // new_mp_size]
    elif old_mp_size % new_mp_size == 0:
        # Merge old MP shards into a single one
        mp_factor = old_mp_size // new_mp_size
        return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
    else:
        raise ValueError(
            f"Either old MP size or new MP size should be a multiple of the other: "
            f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
        )
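
# Illustrative examples of the mapping above (values chosen for exposition, not
# taken from any particular checkpoint):
#   map_mp_rank(2, 8, 5) -> [1]           # new rank 5 reads a slice of old shard 1
#   map_mp_rank(8, 2, 1) -> [4, 5, 6, 7]  # new rank 1 merges old shards 4..7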


def maybe_reshard_state_dict(
    ckpt_paths: List[Path],
    n_kv_heads: int,
    moe_num_experts: Optional[int] = None,
    map_location: Union[str, torch.device] = "cpu",
    mmap: bool = True,
) -> Dict[str, torch.Tensor]:
    """Load the checkpoint shards needed by the current MP rank and reshard them
    if the checkpoint MP size differs from the runtime MP size."""
    if str(map_location) == "cpu":
        torch.set_default_tensor_type(torch.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

    ckpt_paths = np.array(sorted(ckpt_paths))

    new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
    old_mp_size = len(ckpt_paths)
    old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
    paths = ckpt_paths[old_mp_ranks]  # type: ignore
    state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

    if new_mp_size == old_mp_size:
        return state_dicts[0]  # type: ignore

    if moe_num_experts is not None:
        state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]

    print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
    return reshard_mp(
        state_dicts,
        size=max(new_mp_size // old_mp_size, 1),
        rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
        repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
    )
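
# A hypothetical loading sketch (the checkpoint directory, glob pattern, and
# head count below are placeholders, not defined by this module):
#
#   ckpt_dir = Path("/path/to/checkpoint")
#   ckpt_paths = sorted(ckpt_dir.glob("consolidated.*.pth"))
#   state_dict = maybe_reshard_state_dict(ckpt_paths, n_kv_heads=8)
#
# The call is only valid after fairscale's model-parallel groups have been
# initialized, since it queries the current MP rank and world size.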


_WEIGHT_ROW_KEY = {
    "feed_forward.w2",
    "feed_forward.mlp.fc2",
    "attention.wo",
    "feed_forward.mlp.fc2_weight",
    "feed_forward.w_out_shared_DF.weight",
    "attn.wo.weight",
    "mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

_WEIGHT_COLUMN_KEY = {
    "output",
    "feed_forward.(w1|w3)",
    "feed_forward.mlp.(fc1|fc3)",
    "feed_forward.mlp.fc1_weight",
    "attention.(wk|wq|wv|wqkv).weight",
    "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
    "attn.(wk|wq|wv).weight",
    "attn.(wk|wq|wv).bias",
    "mlp.c_fc.weight",
    "mlp.c_fc.bias",
    "conv1._linear.weight",
    "tok_embeddings.weight",
    "vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


def reshard_mp(
    state_dicts: List[Dict[str, torch.Tensor]],
    size: int,
    rank: int,
    repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
    """
    Reshard a list of state dicts into a single state dict given a change in MP size.
    If the list has more than one state dict, we concatenate the values of the same
    key across all state dicts. Otherwise, we just slice it for the current MP rank.
    """

    def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
        if len(tensors) > 1:
            return torch.cat(tensors, dim=dim)
        return tensors[0].chunk(size, dim=dim)[rank].clone()

    def process_key(key: str) -> torch.Tensor:
        if row_regex.search(key):
            return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
        elif column_regex.search(key):
            if "w13" in key or "fc1_weight" in key:
                dims = state_dicts[0][key].size()
                values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
                return concat_or_chunk(values, dim=1).flatten(0, 1)
            elif "qkv" in key:
                q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
                kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
                values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
                return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)])  # type: ignore
            elif "wk.weight" in key or "wv.weight" in key:
                # Support MP > #kv_head
                return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
            elif key == "output.bias" or key == "fc.weight":
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
            elif "w_" in key:
                return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
            else:
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
        else:
            return state_dicts[0][key].clone()

    row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
    column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY

    column_regex = re.compile("|".join(column_keys))
    row_regex = re.compile("|".join(row_keys))

    output: Dict[str, torch.Tensor] = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Note: only processes keys in the first state dict.
        # Assumes keys are the same across all state dicts.
        mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
        for future in concurrent.futures.as_completed(mappings):
            output[mappings[future]] = future.result()
    return output


def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
    routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
    routed_regex = re.compile("|".join(routed_keys))
    keys = list(state_dict.keys())
    for key in keys:
        if routed_regex.search(key):
            state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
    return state_dict
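

# Shape sketch for convert_moe_weights (dimensions are illustrative): a routed
# expert weight stored flattened as (num_experts * D, F) is unflattened to
# (num_experts, D, F); the trailing squeeze(dim=0) only collapses the leading
# axis when num_experts == 1, leaving multi-expert tensors three-dimensional.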