From 63cf5dda5091caf27ac90b9b586597c9b084ebca Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 7 Apr 2025 11:34:28 -0700
Subject: [PATCH] allow changing model parallel size

---
 .../inline/inference/meta_reference/config.py |  3 ++
 .../inference/meta_reference/generators.py    |  4 +-
 llama_stack/templates/dependencies.json       | 45 +------------------
 .../meta-reference-gpu/run-with-safety.yaml   |  6 +++
 .../templates/meta-reference-gpu/run.yaml     |  3 ++
 5 files changed, 15 insertions(+), 46 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 8858f8909..7d089effc 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
+    model_parallel_size: Optional[int] = None
 
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@@ -50,6 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
@@ -59,4 +61,5 @@ class MetaReferenceInferenceConfig(BaseModel):
             "quantization": {
                 "type": quantization_type,
             },
+            "model_parallel_size": model_parallel_size,
         }
diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py
index 5c76dc74a..b820dcbd8 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -147,7 +147,7 @@ class Llama4Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
 
@@ -238,7 +238,7 @@ class Llama3Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
         self.tokenizer = self.inner_generator.tokenizer
diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json
index 931240d37..b8f475cea 100644
--- a/llama_stack/templates/dependencies.json
+++ b/llama_stack/templates/dependencies.json
@@ -356,50 +356,7 @@
     "fairscale",
     "faiss-cpu",
     "fastapi",
-    "fire",
-    "httpx",
-    "langdetect",
-    "lm-format-enforcer",
-    "matplotlib",
-    "mcp",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pymongo",
-    "pypdf",
-    "pythainlp",
-    "redis",
-    "requests",
-    "scikit-learn",
-    "scipy",
-    "sentence-transformers",
-    "sentencepiece",
-    "torch",
-    "torchvision",
-    "tqdm",
-    "transformers",
-    "tree_sitter",
-    "uvicorn",
-    "zmq"
-  ],
-  "meta-reference-quantized-gpu": [
-    "accelerate",
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "emoji",
-    "fairscale",
-    "faiss-cpu",
-    "fastapi",
-    "fbgemm-gpu",
+    "fbgemm-gpu-genai==1.1.2",
     "fire",
     "httpx",
     "langdetect",
diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
index 2cf49cc36..8c7bcbc3c 100644
--- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -27,6 +30,9 @@ providers:
       model: ${env.SAFETY_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index 964dfafeb..e6c143363 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}