Merge branch 'main' into cli

commit 085f9fcce3
Author: Xi Yan (committed by GitHub)
Date: 2024-09-14 14:10:34 -07:00
7 changed files with 21 additions and 10 deletions


@@ -48,8 +48,8 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     )
     parser.add_argument(
         "--model-id",
-        choices=[x.descriptor() for x in models],
         required=False,
+        help="See `llama model list` or `llama model list --show-all` for the list of available models",
     )
     parser.add_argument(
         "--hf-token",


@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]
         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue
 
-            req = model.hardware_requirements
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
            )
         print_table(
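
After this hunk, `llama model list` emits three columns and no longer touches `model.hardware_requirements`. A rough, self-contained sketch of the resulting row assembly (the model records here are fakes and `print_table` is replaced by a plain join):

```python
from dataclasses import dataclass

# Fake model records standing in for llama_models.sku_list entries.
@dataclass
class Model:
    name: str
    huggingface_repo: str
    max_seq_length: int
    is_featured: bool = True

    def descriptor(self) -> str:
        return self.name

models = [Model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", 131072)]
show_all = False

headers = ["Model Descriptor", "HuggingFace Repo", "Context Length"]
rows = []
for model in models:
    if not show_all and not model.is_featured:
        continue
    rows.append(
        [model.descriptor(), model.huggingface_repo, f"{model.max_seq_length // 1024}K"]
    )

print(" | ".join(headers))
for row in rows:
    print(" | ".join(row))  # -> Llama3.1-8B-Instruct | meta-llama/... | 128K
```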


@@ -9,7 +9,7 @@ from typing import Optional
 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model
 from pydantic import BaseModel, Field, field_validator
@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+        return gpu_count
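
The property above keys the model-parallel degree off the SKU descriptor. A standalone sketch of the same logic, with a toy set replacing `llama_models.sku_list.resolve_model` (illustrative names only):

```python
# Toy registry standing in for llama_models.sku_list.resolve_model.
_REGISTRY = {"llama3.1-8b", "llama3.1-70b", "llama3.1-405b"}

def model_parallel_size(model: str) -> int:
    descriptor = model.lower()
    assert descriptor in _REGISTRY, f"Unknown model: `{model}`"
    # 70B/405B checkpoints are sharded 8 ways; smaller SKUs fit on one GPU.
    return 8 if ("-70b" in descriptor or "-405b" in descriptor) else 1

assert model_parallel_size("Llama3.1-405B") == 8
assert model_parallel_size("Llama3.1-8B") == 1
```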


@@ -79,7 +79,8 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
-        model_parallel_size = 1
+        model_parallel_size = config.model_parallel_size
+
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
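
The order here matters: the global NCCL process group must exist before the model-parallel groups are carved out of it. A hedged sketch of that sequence, assuming a `torchrun` launch (which supplies the rendezvous env vars) and fairscale's model-parallel helpers as used by the reference Llama code:

```python
import torch
# fairscale's helpers, as used by the reference generation code.
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)

def setup(model_parallel_size: int) -> None:
    if not torch.distributed.is_initialized():
        # Rendezvous info (RANK, WORLD_SIZE, MASTER_ADDR, ...) comes from torchrun.
        torch.distributed.init_process_group("nccl")
    if not model_parallel_is_initialized():
        # Partitions the WORLD_SIZE ranks into groups of `model_parallel_size`.
        initialize_model_parallel(model_parallel_size)
```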


@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            1,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
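
`partial(init_model_cb, self.config)` freezes the config argument so the worker process can invoke the callback with no arguments later. The pattern in isolation (generic names, not the toolchain's):

```python
from functools import partial

def init_model_cb(config: dict) -> str:
    # Stand-in for the real model-loading callback.
    return f"model loaded with parallel size {config['model_parallel_size']}"

config = {"model_parallel_size": 8}
cb = partial(init_model_cb, config)  # bind config now...
print(cb())                          # ...call with no args later
```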


@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.14
+llama-models>=0.0.16
 pydantic
 requests
 termcolor


@@ -16,7 +16,7 @@ def read_requirements():
 setup(
     name="llama_toolchain",
-    version="0.0.14",
+    version="0.0.16",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",