mirror of https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 18:44:31 +00:00
allow changing model parallel size
This commit is contained in:
parent ff6c47d4e5
commit 63cf5dda50
5 changed files with 15 additions and 46 deletions
@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
+    model_parallel_size: Optional[int] = None

     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
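The new field is optional with a default of None, so existing configs keep validating unchanged. Below is a minimal sketch, assuming pydantic and using only the fields visible in the hunk above; the usage lines at the bottom are illustrative, not from the commit.

from typing import Optional

from pydantic import BaseModel


class MetaReferenceInferenceConfig(BaseModel):
    torch_seed: Optional[int] = None
    max_seq_len: int = 4096
    max_batch_size: int = 1
    model_parallel_size: Optional[int] = None  # new field; None means "derive from checkpoint"


# Omitting the field keeps the default of None, so existing configs stay valid:
cfg = MetaReferenceInferenceConfig()
assert cfg.model_parallel_size is None

# Setting it explicitly overrides the shard count derived from the checkpoint:
cfg = MetaReferenceInferenceConfig(model_parallel_size=2)
assert cfg.model_parallel_size == 2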
@@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
@@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
             "quantization": {
                 "type": quantization_type,
             },
+            "model_parallel_size": model_parallel_size,
         }
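sample_run_config wires the field to a "${env.MODEL_PARALLEL_SIZE:null}" template. The real substitution is performed by llama-stack's run-config loader; the toy resolver below only illustrates the apparent semantics of the syntax, where an unset variable falls back to the default after the colon and "null" reads as None. It is a sketch, not the loader's actual implementation.

import os
import re

# Matches "${env.VAR:default}" and captures the variable name and default.
_TEMPLATE = re.compile(r"\$\{env\.([A-Z0-9_]+):([^}]*)\}")


def resolve(value: str):
    m = _TEMPLATE.fullmatch(value)
    if not m:
        return value  # not a template; pass through unchanged
    name, default = m.group(1), m.group(2)
    raw = os.environ.get(name)
    if raw is not None:
        return raw
    return None if default == "null" else default


# With MODEL_PARALLEL_SIZE unset this prints None, so the checkpoint's shard
# count is used downstream; with MODEL_PARALLEL_SIZE=4 it prints "4".
print(resolve("${env.MODEL_PARALLEL_SIZE:null}"))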
@@ -147,7 +147,7 @@ class Llama4Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )

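The `or` fallback means any falsy model_parallel_size (None, and also 0) defers to the checkpoint's shard count. A self-contained sketch of that behavior with stand-in objects; FakeModel, FakeConfig, and pth_file_count=8 are hypothetical, not names from the commit.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModel:
    pth_file_count: int = 8  # stand-in for the number of checkpoint shards


@dataclass
class FakeConfig:
    model_parallel_size: Optional[int] = None


llama_model = FakeModel()

# Unset: falls back to the checkpoint's shard count.
config = FakeConfig(model_parallel_size=None)
assert (config.model_parallel_size or llama_model.pth_file_count) == 8

# Set explicitly: the configured value wins.
config = FakeConfig(model_parallel_size=2)
assert (config.model_parallel_size or llama_model.pth_file_count) == 2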
@@ -238,7 +238,7 @@ class Llama3Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
         self.tokenizer = self.inner_generator.tokenizer