Nuke hardware_requirements from SKUs

This commit is contained in:
Ashwin Bharambe 2024-09-13 16:39:02 -07:00
parent d8b3fdbd54
commit 19a14cd273
4 changed files with 17 additions and 7 deletions

View file

@ -9,7 +9,7 @@ from typing import Optional
from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
from llama_models.sku_list import all_registered_models, resolve_model
from pydantic import BaseModel, Field, field_validator
@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model
@property
def model_parallel_size(self) -> int:
# HUGE HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
gpu_count = 1
resolved = resolve_model(self.model)
assert resolved is not None
descriptor = resolved.descriptor().lower()
if "-70b" in descriptor or "-405b" in descriptor:
gpu_count = 8
return gpu_count