diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index e2a709c6b..1bfa89fc6 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -48,8 +48,8 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     )
     parser.add_argument(
         "--model-id",
-        choices=[x.descriptor() for x in models],
         required=False,
+        help="See `llama model list` or `llama model list --show-all` for the list of available models",
     )
     parser.add_argument(
         "--hf-token",
diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py
index cbbed7e54..f989260ab 100644
--- a/llama_toolchain/cli/model/list.py
+++ b/llama_toolchain/cli/model/list.py
@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]
 
         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue
 
-            req = model.hardware_requirements
-
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
             )
         print_table(
diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py
index d2e601680..a0bbc5820 100644
--- a/llama_toolchain/inference/meta_reference/config.py
+++ b/llama_toolchain/inference/meta_reference/config.py
@@ -9,7 +9,7 @@ from typing import Optional
 
 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model
 from pydantic import BaseModel, Field, field_validator
 
@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+
+        return gpu_count
diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index 068b3a125..d13b9570d 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -79,7 +79,8 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
-        model_parallel_size = 1
+        model_parallel_size = config.model_parallel_size
+
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
 
diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py
index 0375fe7ab..833f99efd 100644
--- a/llama_toolchain/inference/meta_reference/model_parallel.py
+++ b/llama_toolchain/inference/meta_reference/model_parallel.py
@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
 
     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            1,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
diff --git a/requirements.txt b/requirements.txt
index 106297f09..45ca7ed06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.14
+llama-models>=0.0.16
 pydantic
 requests
 termcolor
diff --git a/setup.py b/setup.py
index 1b0792a20..7273bee51 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def read_requirements():
 
 setup(
     name="llama_toolchain",
-    version="0.0.14",
+    version="0.0.16",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",
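Illustrative usage of the new MetaReferenceImplConfig.model_parallel_size property (a minimal sketch, not part of the diff; the model descriptor strings are assumptions that depend on the installed llama-models sku_list, and other config fields are assumed to have defaults):

    from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig

    # Small checkpoints fall through to the default of a single GPU.
    small = MetaReferenceImplConfig(model="Meta-Llama3.1-8B-Instruct")
    assert small.model_parallel_size == 1

    # Descriptors containing "-70b" or "-405b" resolve to 8-way model parallelism,
    # matching how those checkpoints are sharded.
    large = MetaReferenceImplConfig(model="Meta-Llama3.1-70B-Instruct")
    assert large.model_parallel_size == 8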