Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-01 00:05:18 +00:00
[tmp fix] hardware requirement tmp fix

commit e69e1b8309
parent bafef7ab96

2 changed files with 3 additions and 3 deletions
@@ -28,10 +28,10 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
-from termcolor import cprint
 
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
+from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
 
@@ -79,7 +79,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
 
-        model_parallel_size = model.hardware_requirements.gpu_count
+        model_parallel_size = 1
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
 
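For context on the value being pinned: resolve_model (imported above from llama_models.sku_list) looks up the model's SKU metadata, and its hardware_requirements.gpu_count field says how many GPUs the reference checkpoint expects for model parallelism. A minimal before/after sketch, assuming that attribute layout; the helper function below is illustrative, not code from this repository:

    from llama_models.sku_list import resolve_model

    def model_parallel_size_for(model_id: str, use_hardware_requirements: bool) -> int:
        # Hypothetical helper contrasting the pre-fix and post-fix behaviour.
        model = resolve_model(model_id)
        if model is None:
            raise ValueError(f"unknown model: {model_id}")
        if use_hardware_requirements:
            # Pre-fix: size the model-parallel group to the checkpoint's GPU count.
            return model.hardware_requirements.gpu_count
        # Temporary fix in this commit: always run a single model-parallel rank.
        return 1

With the fix applied, initialize_model_parallel() is always called with a group size of 1, regardless of what the resolved SKU reports. The second changed file follows.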
@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
 
     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            self.model.hardware_requirements.gpu_count,
+            1,
             init_model_cb=partial(init_model_cb, self.config),
        )
         self.group.start()
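The hunk above applies the same pin where the model-parallel worker group is spawned. A rough, self-contained sketch of the pattern follows; ModelParallelProcessGroup is replaced by a stub so it runs on its own, and __exit__ plus the stop() call are assumptions, since the diff only shows __enter__:

    from functools import partial

    class ModelParallelProcessGroup:
        # Stub standing in for the real llama_toolchain class; it only records
        # the requested group size and invokes the init callback.
        def __init__(self, model_parallel_size, init_model_cb):
            self.model_parallel_size = model_parallel_size
            self.init_model_cb = init_model_cb
            self.started = False

        def start(self):
            self.init_model_cb()
            self.started = True

        def stop(self):
            self.started = False

    def init_model_cb(config):
        # Placeholder for the callback that loads the model inside each worker.
        print(f"initializing model for config: {config}")

    class LlamaModelParallelGenerator:
        def __init__(self, config):
            self.config = config

        def __enter__(self):
            # Pre-fix, the first argument was
            # self.model.hardware_requirements.gpu_count; the tmp fix spawns a
            # single model-parallel worker instead.
            self.group = ModelParallelProcessGroup(
                1,
                init_model_cb=partial(init_model_cb, self.config),
            )
            self.group.start()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Assumed teardown; the diff does not show how the group is shut down.
            self.group.stop()

    with LlamaModelParallelGenerator(config="demo-config") as gen:
        assert gen.group.model_parallel_size == 1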