mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 17:23:00 +00:00 
			
		
		
		
	# What does this PR do? The goal of this PR is codebase modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred in recent Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only changes worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			164 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			164 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import concurrent.futures
 | |
| import re
 | |
| from pathlib import Path
 | |
| from typing import Any
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
 | |
| 
 | |
| 
 | |
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
    """Return the old MP ranks whose shards are needed by ``new_mp_rank``.

    Two resize directions are supported:

    * ``new_mp_size`` is a multiple of ``old_mp_size`` (splitting): several new
      ranks share one old shard, so a single old rank is returned.
    * ``old_mp_size`` is a multiple of ``new_mp_size`` (merging): each new rank
      absorbs a contiguous run of ``old_mp_size // new_mp_size`` old ranks.

    Raises:
        ValueError: if neither size is a multiple of the other.
    """
    splitting = new_mp_size % old_mp_size == 0
    if splitting:
        # Integer division maps each new rank onto the old shard it is cut from.
        return [(new_mp_rank * old_mp_size) // new_mp_size]

    merging = old_mp_size % new_mp_size == 0
    if not merging:
        raise ValueError(
            f"Either old MP size or new MP size should be a multiple of the other: "
            f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
        )

    factor = old_mp_size // new_mp_size
    first_old_rank = new_mp_rank * factor
    return list(range(first_old_rank, first_old_rank + factor))
 | |
| 
 | |
| 
 | |
| def maybe_reshard_state_dict(
 | |
|     ckpt_paths: list[Path],
 | |
|     n_kv_heads: int,
 | |
|     moe_num_experts: int | None = None,
 | |
|     map_location: str | torch.device = "cpu",
 | |
|     mmap: bool = True,
 | |
| ) -> dict[str, torch.Tensor]:
 | |
|     if str(map_location) == "cpu":
 | |
|         torch.set_default_tensor_type(torch.BFloat16Tensor)
 | |
|     else:
 | |
|         torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
 | |
| 
 | |
|     ckpt_paths = np.array(sorted(ckpt_paths))
 | |
| 
 | |
|     new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
 | |
|     old_mp_size = len(ckpt_paths)
 | |
|     old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
 | |
| 
 | |
|     print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
 | |
|     paths = ckpt_paths[old_mp_ranks]  # type: ignore
 | |
|     state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
 | |
| 
 | |
|     if new_mp_size == old_mp_size:
 | |
|         return state_dicts[0]  # type: ignore
 | |
| 
 | |
|     if moe_num_experts is not None:
 | |
|         state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
 | |
| 
 | |
|     print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
 | |
|     return reshard_mp(
 | |
|         state_dicts,
 | |
|         size=max(new_mp_size // old_mp_size, 1),
 | |
|         rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
 | |
|         repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
 | |
|     )
 | |
| 
 | |
| 
 | |
# Weight-name regex fragments used for resharding.
#
# Each set is joined with "|" into one alternation and tested with
# ``re.search``, so a fragment may match anywhere inside a state-dict key.
# The "." characters are NOT escaped and therefore match any character,
# not only a literal dot.
#
# "Row" weights are concatenated/chunked along their last dimension (dim=-1)
# by reshard_mp.
_WEIGHT_ROW_KEY = {
    "attention.wo",
    "attn.wo.weight",
    "feed_forward.mlp.fc2",
    "feed_forward.mlp.fc2_weight",
    "feed_forward.w2",
    "feed_forward.w_out_shared_DF.weight",
    "mlp.c_proj.weight",
}
# Routed-expert row weights. convert_moe_weights also reshapes these.
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

# "Column" weights are concatenated/chunked along a leading dimension (dim 0,
# or dim -2 for keys containing "w_"). reshard_mp applies extra handling to
# fused w13/fc1, fused qkv, and wk/wv weights.
_WEIGHT_COLUMN_KEY = {
    "attention.(wk|wq|wv|wqkv).weight",
    "attn.(wk|wq|wv).bias",
    "attn.(wk|wq|wv).weight",
    "conv1._linear.weight",
    "feed_forward.(w1|w3)",
    "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
    "feed_forward.mlp.(fc1|fc3)",
    "feed_forward.mlp.fc1_weight",
    "mlp.c_fc.bias",
    "mlp.c_fc.weight",
    "output",
    "tok_embeddings.weight",
    "vision_projection.weight",
}
# Routed-expert column weights. convert_moe_weights also reshapes these.
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
 | |
| 
 | |
| 
 | |
def reshard_mp(
    state_dicts: list[dict[str, torch.Tensor]],
    size: int,
    rank: int,
    repeat_qk_qv: int = 1,
) -> dict[str, torch.Tensor]:
    """
    Reshard a list of state dicts into a single state dict given a change in MP size.

    If the list has more than one state dict, each key's values are concatenated
    across all state dicts (merging shards). Otherwise the single state dict's
    values are chunked into ``size`` pieces and piece ``rank`` is kept
    (splitting a shard).

    Keys matching the "row" patterns are merged/split along the last dimension.
    Keys matching the "column" patterns are merged/split along a leading
    dimension, with special handling for fused w13/fc1, fused qkv, and wk/wv
    weights. All other keys are cloned from the first state dict.

    Args:
        state_dicts: Old-MP shards to merge, or a single shard to split.
        size: Number of chunks to split into (only used with one state dict).
        rank: Index of the chunk to keep (only used with one state dict).
        repeat_qk_qv: How many times wk/wv weights are tiled along dim 0 before
            being merged/split.

    Returns:
        A new state dict whose keys follow the order of ``state_dicts[0]``.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    if not state_dicts:
        raise ValueError("reshard_mp requires at least one state dict")

    row_regex = re.compile("|".join(_WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY))
    column_regex = re.compile("|".join(_WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY))

    def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
        # Merge when several shards are given; otherwise keep this rank's slice.
        if len(tensors) > 1:
            return torch.cat(tensors, dim=dim)
        # chunk() returns views into the source tensor; clone so the slice owns
        # its own storage.
        return tensors[0].chunk(size, dim=dim)[rank].clone()

    def process_key(key: str) -> torch.Tensor:
        if row_regex.search(key):
            return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
        elif column_regex.search(key):
            if "w13" in key or "fc1_weight" in key:
                # Fused tensor stored as two stacked halves along dim 0. View it
                # as (2, half, ...) so each half is merged/split independently,
                # then flatten the halves back together.
                dims = state_dicts[0][key].size()
                values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
                return concat_or_chunk(values, dim=1).flatten(0, 1)
            elif "qkv" in key:
                # Fused q/k/v projection. The q width is read from the matching
                # output projection ("qkv" -> "o"), and the remainder is split
                # evenly between k and v. Each part is merged/split on its own,
                # then the parts are re-concatenated in q, k, v order.
                q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
                kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
                values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
                return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)])  # type: ignore
            elif "wk.weight" in key or "wv.weight" in key:
                # Support MP > #kv_head: tile k/v weights along dim 0 so every
                # new rank receives a slice.
                # NOTE(review): Tensor.repeat tiles the whole matrix rather than
                # repeating each head in place; confirm this gives the intended
                # kv-head-to-rank assignment when one shard holds several heads.
                return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
            elif key == "output.bias" or key == "fc.weight":
                # Same treatment as the final else branch; kept explicit.
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
            elif "w_" in key:
                # e.g. the w_in_shared_FD / w_swiglu_FD / moe_w_out_eF_D patterns.
                return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
            else:
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
        else:
            # Not matched by any sharding pattern (presumably replicated across
            # shards); the first shard's copy is used.
            return state_dicts[0][key].clone()

    # Note: only processes keys in the first state dict.
    # Assumes keys are the same across all state dicts.
    keys = list(state_dicts[0])
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Executor.map yields results in submission order. The output therefore
        # keeps the input's key order instead of the nondeterministic order in
        # which worker threads happen to finish.
        results = executor.map(process_key, keys)
        output: dict[str, torch.Tensor] = dict(zip(keys, results, strict=True))
    return output
 | |
| 
 | |
| 
 | |
def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
    """Give routed-expert weights in ``state_dict`` an explicit leading expert dimension.

    Every tensor whose key matches one of the MoE row/column patterns is
    unflattened along dim 0 into ``(num_experts, -1)`` and then passed through
    ``squeeze(dim=0)``. The squeeze only drops that dimension when
    ``num_experts == 1``.

    The dict is modified in place and also returned. Converted entries are
    popped and re-inserted, so they move to the end of the dict's key order.
    """
    expert_pattern = re.compile("|".join(_MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY))
    # Snapshot the matching names first: the loop below mutates the dict.
    matching_names = [name for name in state_dict if expert_pattern.search(name)]
    for name in matching_names:
        flat_weight = state_dict.pop(name)
        state_dict[name] = flat_weight.unflatten(0, (num_experts, -1)).squeeze(dim=0)
    return state_dict
 |