Nuke hardware_requirements from SKUs

Ashwin Bharambe 2024-09-13 16:39:02 -07:00
parent d8b3fdbd54
commit 19a14cd273
4 changed files with 17 additions and 7 deletions

@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]
         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue
-            req = model.hardware_requirements
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
             )
         print_table(
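
With the hardware column gone, each row of the model listing carries only the descriptor, the HuggingFace repo, and the context length. A minimal sketch of the shape of one resulting row (the concrete values here are illustrative, not taken from this commit):

    # Illustrative only: one table row after this change.
    row = [
        "Llama3.1-8B-Instruct",                   # model.descriptor()
        "meta-llama/Meta-Llama-3.1-8B-Instruct",  # model.huggingface_repo
        f"{131072 // 1024}K",                     # max_seq_length rendered as "128K"
    ]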

@@ -9,7 +9,7 @@ from typing import Optional
 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model
 from pydantic import BaseModel, Field, field_validator
@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+        return gpu_count
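
The new property reduces to a descriptor heuristic: any SKU whose descriptor contains `-70b` or `-405b` is assumed to need 8-way model parallelism, and everything else a single GPU. A standalone sketch of the same rule (the descriptor strings below are illustrative):

    def guess_model_parallel_size(descriptor: str) -> int:
        # Mirrors the heuristic above: 70B/405B variants shard across
        # 8 GPUs; every other SKU defaults to one.
        descriptor = descriptor.lower()
        return 8 if ("-70b" in descriptor or "-405b" in descriptor) else 1

    assert guess_model_parallel_size("Llama3.1-8B-Instruct") == 1
    assert guess_model_parallel_size("Llama3.1-70B-Instruct") == 8
    assert guess_model_parallel_size("Llama3.1-405B-Instruct:bf16-mp8") == 8

As the HACK comment says, this is a stopgap until the parallel size becomes explicit per-model configuration in ModelsRegistry.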

@@ -79,7 +79,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
-        model_parallel_size = model.hardware_requirements.gpu_count
+        model_parallel_size = config.model_parallel_size
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
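
For context, this sits in the standard torch.distributed + model-parallel bring-up; a sketch of the sequence after the change, assuming the helpers come from fairscale's `fairscale.nn.model_parallel.initialize` (the diff shows only their names, not their import):

    import torch
    from fairscale.nn.model_parallel.initialize import (
        initialize_model_parallel,
        model_parallel_is_initialized,
    )

    def setup_parallelism(config) -> None:
        # The process group comes from the launcher (e.g. torchrun);
        # the parallel size now comes from the config property, not the SKU.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            initialize_model_parallel(config.model_parallel_size)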

@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            self.model.hardware_requirements.gpu_count,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()