allow changing model parallel size

This commit is contained in:
Ashwin Bharambe 2025-04-07 11:34:28 -07:00
parent ff6c47d4e5
commit 63cf5dda50
5 changed files with 15 additions and 46 deletions

View file

@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
**kwargs,
) -> Dict[str, Any]:
return {
@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
}

View file

@ -147,7 +147,7 @@ class Llama4Generator:
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
@ -238,7 +238,7 @@ class Llama3Generator:
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer

View file

@ -356,50 +356,7 @@
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fbgemm-gpu-genai==1.1.2",
"fire",
"httpx",
"langdetect",

View file

@ -18,6 +18,9 @@ providers:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@ -27,6 +30,9 @@ providers:
model: ${env.SAFETY_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

View file

@ -18,6 +18,9 @@ providers:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}