mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 09:12:28 +00:00
chore: add mypy inference fp8_impls
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
d880c2df0e
commit
1c08a1cae9
7 changed files with 38 additions and 25 deletions
|
|
@@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
|
|||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
||||
ckpt_paths = np.array(sorted(ckpt_paths))
|
||||
ckpt_paths_array = np.array(sorted(ckpt_paths))
|
||||
|
||||
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
|
||||
old_mp_size = len(ckpt_paths)
|
||||
old_mp_size = len(ckpt_paths_array)
|
||||
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
|
||||
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
|
||||
paths = ckpt_paths[old_mp_ranks] # type: ignore
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
|
||||
paths = ckpt_paths_array[old_mp_ranks]
|
||||
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
|
||||
|
||||
if new_mp_size == old_mp_size:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue