Allow changing model parallel size

This commit is contained in:
Ashwin Bharambe 2025-04-07 11:34:28 -07:00
parent ff6c47d4e5
commit 63cf5dda50
5 changed files with 15 additions and 46 deletions

View file

@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
**kwargs,
) -> Dict[str, Any]:
return {
@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
}

View file

@ -147,7 +147,7 @@ class Llama4Generator:
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
@ -238,7 +238,7 @@ class Llama3Generator:
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer