mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00

allow changing model parallel size

parent ff6c47d4e5
commit 63cf5dda50

5 changed files with 15 additions and 46 deletions
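
This commit adds an optional `model_parallel_size` setting to the meta-reference inference provider, so the model-parallel world size can be set explicitly (via config or the `MODEL_PARALLEL_SIZE` environment variable) instead of always being inferred from the number of checkpoint `.pth` shards. It also consolidates the dependency registry: the separate `meta-reference-quantized-gpu` list is dropped and `meta-reference-gpu` gains the pinned `fbgemm-gpu-genai==1.1.2`.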

llama_stack/providers/inline/inference/meta_reference/config.py

@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
+    model_parallel_size: Optional[int] = None

     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
@@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
             "quantization": {
                 "type": quantization_type,
             },
+            "model_parallel_size": model_parallel_size,
         }
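
The `${env.MODEL_PARALLEL_SIZE:null}` default follows the same environment-substitution syntax as the neighboring `checkpoint_dir` and `quantization_type` parameters: use the variable's value if it is set, otherwise the literal default, where `null` stands for no value. A minimal sketch of that resolution, assuming exactly these semantics (`resolve_env_default` is an illustrative helper, not the stack's actual substitution code):

import os
import re
from typing import Optional

_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>\w+):(?P<default>[^}]*)\}")

def resolve_env_default(template: str) -> Optional[str]:
    # Hypothetical resolver for "${env.VAR:default}" templates: take the
    # environment value if set, else the literal default; a "null" default
    # maps to None.
    m = _ENV_PATTERN.fullmatch(template)
    if m is None:
        return template  # plain value, not a template
    value = os.environ.get(m.group("name"), m.group("default"))
    return None if value == "null" else value

os.environ.pop("MODEL_PARALLEL_SIZE", None)
assert resolve_env_default("${env.MODEL_PARALLEL_SIZE:null}") is None
os.environ["MODEL_PARALLEL_SIZE"] = "2"
assert resolve_env_default("${env.MODEL_PARALLEL_SIZE:null}") == "2"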

llama_stack/providers/inline/inference/meta_reference/generators.py

@@ -147,7 +147,7 @@ class Llama4Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
@@ -238,7 +238,7 @@ class Llama3Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
         self.tokenizer = self.inner_generator.tokenizer
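
Both constructors now resolve the model-parallel world size the same way: an explicit `model_parallel_size` wins, and the old default of one rank per checkpoint `.pth` shard remains the fallback. A minimal sketch of that logic (`resolve_world_size` is an illustrative helper, not a function in the repo; the config class is trimmed to the relevant fields):

from typing import Optional
from pydantic import BaseModel

class MetaReferenceInferenceConfig(BaseModel):
    max_seq_len: int = 4096
    max_batch_size: int = 1
    model_parallel_size: Optional[int] = None  # the new override

def resolve_world_size(config: MetaReferenceInferenceConfig, pth_file_count: int) -> int:
    # Mirrors `config.model_parallel_size or llama_model.pth_file_count`.
    return config.model_parallel_size or pth_file_count

assert resolve_world_size(MetaReferenceInferenceConfig(), 8) == 8  # fallback: shard count
assert resolve_world_size(MetaReferenceInferenceConfig(model_parallel_size=2), 8) == 2

Note that with `or`, only a truthy value overrides: leaving the field as None (or setting 0) falls through to the shard count.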

distributions/dependencies.json

@@ -356,50 +356,7 @@
     "fairscale",
     "faiss-cpu",
     "fastapi",
-    "fire",
-    "httpx",
-    "langdetect",
-    "lm-format-enforcer",
-    "matplotlib",
-    "mcp",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pymongo",
-    "pypdf",
-    "pythainlp",
-    "redis",
-    "requests",
-    "scikit-learn",
-    "scipy",
-    "sentence-transformers",
-    "sentencepiece",
-    "torch",
-    "torchvision",
-    "tqdm",
-    "transformers",
-    "tree_sitter",
-    "uvicorn",
-    "zmq"
-  ],
-  "meta-reference-quantized-gpu": [
-    "accelerate",
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "emoji",
-    "fairscale",
-    "faiss-cpu",
-    "fastapi",
-    "fbgemm-gpu",
+    "fbgemm-gpu-genai==1.1.2",
     "fire",
     "httpx",
     "langdetect",
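
This hunk accounts for nearly all of the commit's deletions: the standalone `meta-reference-quantized-gpu` dependency list is removed, and `meta-reference-gpu` picks up the quantization dependency directly as the pinned `fbgemm-gpu-genai==1.1.2`.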

llama_stack/templates/meta-reference-gpu/run-with-safety.yaml

@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -27,6 +30,9 @@ providers:
       model: ${env.SAFETY_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss

llama_stack/templates/meta-reference-gpu/run.yaml

@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
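
With the templates updated, a launch along the lines of `MODEL_PARALLEL_SIZE=2 llama stack run <run.yaml>` (invocation illustrative) should pin the provider to two model-parallel ranks, while leaving the variable unset keeps the previous one-rank-per-shard behavior.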