[tmp fix] hardware requirement tmp fix

This commit is contained in:
Xi Yan 2024-09-14 14:06:36 -07:00
parent bafef7ab96
commit e69e1b8309
2 changed files with 3 additions and 3 deletions

View file

@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:
def __enter__(self):
self.group = ModelParallelProcessGroup(
self.model.hardware_requirements.gpu_count,
1,
init_model_cb=partial(init_model_cb, self.config),
)
self.group.start()