allow changing model parallel size

Author: Ashwin Bharambe, 2025-04-07 11:34:28 -07:00
parent ff6c47d4e5
commit 63cf5dda50
5 changed files with 15 additions and 46 deletions


@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
+    model_parallel_size: Optional[int] = None
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
@@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
             "quantization": {
                 "type": quantization_type,
             },
+            "model_parallel_size": model_parallel_size,
         }
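
For illustration, a minimal sketch of how the new knob might be consumed. The snippet resolves the "${env.MODEL_PARALLEL_SIZE:null}" template by hand; the import path, the checkpoint fallback behavior described in the comments, and the example value 2 are assumptions, not part of this commit.

import os

# Import path assumed from the llama-stack repo layout.
from llama_stack.providers.inline.inference.meta_reference.config import (
    MetaReferenceInferenceConfig,
)

# Resolve MODEL_PARALLEL_SIZE the way the "${env.MODEL_PARALLEL_SIZE:null}"
# template would: unset means None, which presumably keeps the pre-commit
# behavior of deriving the parallelism from the model itself.
raw = os.environ.get("MODEL_PARALLEL_SIZE")
model_parallel_size = int(raw) if raw else None

config = MetaReferenceInferenceConfig(
    model="Llama3.2-3B-Instruct",  # default from the sample config above
    max_seq_len=4096,
    max_batch_size=1,
    model_parallel_size=model_parallel_size,  # e.g. 2 shards across 2 ranks
)
print(config.model_parallel_size)

Running with MODEL_PARALLEL_SIZE=2 would set the field to 2; leaving it unset yields None, matching the Optional[int] default added in the first hunk.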