chore: add mypy inference fp8_impls

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-09 00:22:45 +02:00
parent d880c2df0e
commit 1c08a1cae9
7 changed files with 38 additions and 25 deletions

View file

@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
ckpt_paths_array = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_size = len(ckpt_paths_array)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
paths = ckpt_paths_array[old_mp_ranks]
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size: