Nuke hardware_requirements from SKUs

Ashwin Bharambe 2024-09-13 16:39:02 -07:00
parent d8b3fdbd54
commit 19a14cd273
4 changed files with 17 additions and 7 deletions


@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]
         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue
-            req = model.hardware_requirements
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
             )
         print_table(
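With the "Hardware Requirements" column gone, each table row carries three fields. A minimal sketch of the resulting row shape; only the f-string formatting is from the diff, the concrete model values are illustrative:

    descriptor = "Meta-Llama3.1-8B-Instruct"                   # illustrative
    huggingface_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct" # illustrative
    max_seq_length = 131072

    row = [descriptor, huggingface_repo, f"{max_seq_length // 1024}K"]
    # -> ["Meta-Llama3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", "128K"]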


@@ -9,7 +9,7 @@ from typing import Optional
 from llama_models.datatypes import ModelFamily
 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model
 from pydantic import BaseModel, Field, field_validator
@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+        return gpu_count
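A quick sketch of how the new property resolves, assuming the config's validator accepts the model name and resolve_model finds it; the model descriptors below are illustrative:

    config = MetaReferenceImplConfig(model="Meta-Llama3.1-70B-Instruct")
    config.model_parallel_size  # 8: "-70b" appears in the lowercased descriptor

    config = MetaReferenceImplConfig(model="Meta-Llama3.1-8B-Instruct")
    config.model_parallel_size  # 1: neither "-70b" nor "-405b" matches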


@@ -79,7 +79,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
-        model_parallel_size = model.hardware_requirements.gpu_count
+        model_parallel_size = config.model_parallel_size
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
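For context, the initialization path after this change, sketched with the fairscale helpers the Llama reference implementation imports; config is assumed to be the MetaReferenceImplConfig from the previous file:

    import torch
    from fairscale.nn.model_parallel.initialize import (
        initialize_model_parallel,
        model_parallel_is_initialized,
    )

    # The model-parallel degree now comes from the config property rather
    # than from the SKU's hardware_requirements.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")
    if not model_parallel_is_initialized():
        initialize_model_parallel(config.model_parallel_size)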


@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            self.model.hardware_requirements.gpu_count,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
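A hypothetical usage sketch of the generator; the constructor arguments are assumed, only the __enter__ behavior is taken from the diff:

    generator = LlamaModelParallelGenerator(config)  # constructor args assumed
    with generator:
        # Entering the context spawns a process group sized by
        # config.model_parallel_size (e.g. 8 processes for a 70B model).
        ...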