From 19a14cd273f95d5a2b4eea80eb5aa63198b46233 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 13 Sep 2024 16:39:02 -0700
Subject: [PATCH 1/5] Nuke hardware_requirements from SKUs

---
 llama_toolchain/cli/model/list.py              |  4 ----
 .../inference/meta_reference/config.py         | 16 +++++++++++++++-
 .../inference/meta_reference/generation.py     |  2 +-
 .../inference/meta_reference/model_parallel.py |  2 +-
 4 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py
index cbbed7e54..f989260ab 100644
--- a/llama_toolchain/cli/model/list.py
+++ b/llama_toolchain/cli/model/list.py
@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]

         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
            if not args.show_all and not model.is_featured:
                continue

-            req = model.hardware_requirements
-
            descriptor = model.descriptor()
            rows.append(
                [
                    descriptor,
                    model.huggingface_repo,
                    f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                ]
            )
        print_table(
diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py
index d2e601680..a0bbc5820 100644
--- a/llama_toolchain/inference/meta_reference/config.py
+++ b/llama_toolchain/inference/meta_reference/config.py
@@ -9,7 +9,7 @@ from typing import Optional

 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model

 from pydantic import BaseModel, Field, field_validator

@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
            )
        return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+
+        return gpu_count
diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index 1329f8699..4164dde9e 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -79,7 +79,7 @@ class Llama:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")

-        model_parallel_size = model.hardware_requirements.gpu_count
+        model_parallel_size = config.model_parallel_size

        if not model_parallel_is_initialized():
            initialize_model_parallel(model_parallel_size)
diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py
index b5d81287b..833f99efd 100644
--- a/llama_toolchain/inference/meta_reference/model_parallel.py
+++ b/llama_toolchain/inference/meta_reference/model_parallel.py
@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:

    def __enter__(self):
        self.group = ModelParallelProcessGroup(
-            self.model.hardware_requirements.gpu_count,
+            self.config.model_parallel_size,
            init_model_cb=partial(init_model_cb, self.config),
        )
        self.group.start()

From 498cf036173258ee627f5525a77197cd4ab0c927 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 13 Sep 2024 17:04:43 -0700
Subject: [PATCH 2/5] add pypdf

---
 llama_toolchain/memory/providers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
index adfff2e71..40a11235b 100644
--- a/llama_toolchain/memory/providers.py
+++ b/llama_toolchain/memory/providers.py
@@ -11,7 +11,7 @@ from llama_toolchain.core.datatypes import *  # noqa: F403

 EMBEDDING_DEPS = [
     "blobfile",
     "chardet",
-    "PdfReader",
+    "pypdf",
     "sentence-transformers",
 ]

From 7a283ea07639cd70c1bd3d1327ce1c649fd38a9d Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 13 Sep 2024 17:23:12 -0700
Subject: [PATCH 3/5] Bump version to 0.0.15

---
 requirements.txt | 2 +-
 setup.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 106297f09..5502db952 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.14
+llama-models>=0.0.15
 pydantic
 requests
 termcolor
diff --git a/setup.py b/setup.py
index 1b0792a20..19913d1e7 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def read_requirements():

 setup(
     name="llama_toolchain",
-    version="0.0.14",
+    version="0.0.15",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",

From 49ce36426f31f2d2f1d243674b5dc02782f98d3a Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sat, 14 Sep 2024 08:06:34 -0700
Subject: [PATCH 4/5] Make llama model download error message a bit better

---
 llama_toolchain/cli/download.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index e2a709c6b..1bfa89fc6 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -48,8 +48,8 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
    )
    parser.add_argument(
        "--model-id",
-        choices=[x.descriptor() for x in models],
        required=False,
+        help="See `llama model list` or `llama model list --show-all` for the list of available models",
    )
    parser.add_argument(
        "--hf-token",

From 53ab18d6bb742b97a899dcbbe728c7cb8bdc098d Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sat, 14 Sep 2024 08:09:45 -0700
Subject: [PATCH 5/5] Bump version to 0.0.16

---
 requirements.txt | 2 +-
 setup.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 5502db952..45ca7ed06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.15
+llama-models>=0.0.16
 pydantic
 requests
 termcolor
diff --git a/setup.py b/setup.py
index 19913d1e7..7273bee51 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def read_requirements():

 setup(
     name="llama_toolchain",
-    version="0.0.15",
+    version="0.0.16",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",