[tmp fix] Temporarily hard-code model parallel size to 1 instead of hardware_requirements.gpu_count

This commit is contained in:
Xi Yan 2024-09-14 14:06:36 -07:00
parent bafef7ab96
commit e69e1b8309
2 changed files with 3 additions and 3 deletions

View file

@@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.inference.api import QuantizationType
from termcolor import cprint
from .config import MetaReferenceImplConfig
@@ -79,7 +79,7 @@ class Llama:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
-model_parallel_size = model.hardware_requirements.gpu_count
+model_parallel_size = 1
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)

View file

@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
def __enter__(self):
self.group = ModelParallelProcessGroup(
-self.model.hardware_requirements.gpu_count,
+1,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()