From 19a14cd273f95d5a2b4eea80eb5aa63198b46233 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 13 Sep 2024 16:39:02 -0700
Subject: [PATCH] Nuke hardware_requirements from SKUs

---
 llama_toolchain/cli/model/list.py              |  4 ----
 .../inference/meta_reference/config.py         | 16 +++++++++++++++-
 .../inference/meta_reference/generation.py     |  2 +-
 .../inference/meta_reference/model_parallel.py |  2 +-
 4 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py
index cbbed7e54..f989260ab 100644
--- a/llama_toolchain/cli/model/list.py
+++ b/llama_toolchain/cli/model/list.py
@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]

         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue

-            req = model.hardware_requirements
-
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
             )
         print_table(
diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py
index d2e601680..a0bbc5820 100644
--- a/llama_toolchain/inference/meta_reference/config.py
+++ b/llama_toolchain/inference/meta_reference/config.py
@@ -9,7 +9,7 @@ from typing import Optional

 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model

 from pydantic import BaseModel, Field, field_validator

@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+
+        return gpu_count
diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index 1329f8699..4164dde9e 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -79,7 +79,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = model.hardware_requirements.gpu_count
+        model_parallel_size = config.model_parallel_size
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)

diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py
index b5d81287b..833f99efd 100644
--- a/llama_toolchain/inference/meta_reference/model_parallel.py
+++ b/llama_toolchain/inference/meta_reference/model_parallel.py
@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:

     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            self.model.hardware_requirements.gpu_count,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
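
For reference, a minimal standalone sketch of the descriptor-based heuristic the new
`model_parallel_size` property encodes. The real property resolves the configured model
via `resolve_model()` from `llama_models.sku_list`; in this sketch the descriptor is
passed in directly (the helper name and the example descriptors are illustrative only)
so it runs without the library installed.

    def model_parallel_size_for(descriptor: str) -> int:
        # Default to a single GPU; 70B and 405B variants are sharded across 8 GPUs.
        descriptor = descriptor.lower()
        if "-70b" in descriptor or "-405b" in descriptor:
            return 8
        return 1

    if __name__ == "__main__":
        # Hypothetical descriptors, for illustration only.
        for name in ("Llama3.1-8B-Instruct", "Llama3.1-70B-Instruct", "Llama3.1-405B-Instruct"):
            print(f"{name} -> {model_parallel_size_for(name)} GPU(s)")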