Nuke hardware_requirements from SKUs

This commit is contained in:
Ashwin Bharambe 2024-09-13 16:39:02 -07:00
parent d8b3fdbd54
commit 19a14cd273
4 changed files with 17 additions and 7 deletions

View file

@ -79,7 +79,7 @@ class Llama:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = model.hardware_requirements.gpu_count
model_parallel_size = config.model_parallel_size
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)