Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 08:00:09 +00:00)
commit 085f9fcce3
Merge branch 'main' into cli

7 changed files with 21 additions and 10 deletions
@@ -48,8 +48,8 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     )
     parser.add_argument(
         "--model-id",
-        choices=[x.descriptor() for x in models],
         required=False,
+        help="See `llama model list` or `llama model list --show-all` for the list of available models",
     )
     parser.add_argument(
         "--hf-token",
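Note on the hunk above: `--model-id` drops its `choices=` validation, so the CLI no longer needs updating in lockstep with the model list; the help text points users at `llama model list` instead. A minimal, hypothetical standalone sketch of the resulting argparse behavior (the real parser is wired into the `llama download` subcommand):

    import argparse

    # Any string is accepted for --model-id; validation now happens
    # downstream rather than at argparse time.
    parser = argparse.ArgumentParser(prog="llama download")
    parser.add_argument(
        "--model-id",
        required=False,
        help="See `llama model list` or `llama model list --show-all` for the list of available models",
    )
    parser.add_argument("--hf-token", required=False)

    args = parser.parse_args(["--model-id", "Meta-Llama3.1-8B"])
    print(args.model_id)  # -> Meta-Llama3.1-8B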
@@ -38,7 +38,6 @@ class ModelList(Subcommand):
             "Model Descriptor",
             "HuggingFace Repo",
             "Context Length",
-            "Hardware Requirements",
         ]

         rows = []
@@ -46,15 +45,12 @@ class ModelList(Subcommand):
             if not args.show_all and not model.is_featured:
                 continue

-            req = model.hardware_requirements
-
             descriptor = model.descriptor()
             rows.append(
                 [
                     descriptor,
                     model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
-                    f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
             )
         print_table(
@@ -9,7 +9,7 @@ from typing import Optional
 from llama_models.datatypes import ModelFamily

 from llama_models.schema_utils import json_schema_type
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import all_registered_models, resolve_model

 from pydantic import BaseModel, Field, field_validator

@@ -41,3 +41,17 @@ class MetaReferenceImplConfig(BaseModel):
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
         return model
+
+    @property
+    def model_parallel_size(self) -> int:
+        # HUGE HACK ALERT: this will be fixed when we move inference configuration
+        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
+        # as configuration there
+        gpu_count = 1
+        resolved = resolve_model(self.model)
+        assert resolved is not None
+        descriptor = resolved.descriptor().lower()
+        if "-70b" in descriptor or "-405b" in descriptor:
+            gpu_count = 8
+
+        return gpu_count
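The new `model_parallel_size` property infers the parallel degree from the model descriptor rather than from explicit configuration (hence the HACK comment). A self-contained sketch of the same heuristic, taking the descriptor directly since the real `resolve_model` lookup lives in `llama_models.sku_list`:

    def infer_model_parallel_size(descriptor: str) -> int:
        # 70B and 405B checkpoints ship as 8-way model-parallel shards;
        # everything smaller runs on a single GPU.
        d = descriptor.lower()
        if "-70b" in d or "-405b" in d:
            return 8
        return 1

    # Example descriptors are illustrative.
    assert infer_model_parallel_size("Meta-Llama3.1-70B-Instruct") == 8
    assert infer_model_parallel_size("Meta-Llama3.1-8B") == 1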
@@ -79,7 +79,8 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = 1
+        model_parallel_size = config.model_parallel_size
+
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)

@@ -79,7 +79,7 @@ class LlamaModelParallelGenerator:

     def __enter__(self):
         self.group = ModelParallelProcessGroup(
-            1,
+            self.config.model_parallel_size,
             init_model_cb=partial(init_model_cb, self.config),
         )
         self.group.start()
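The two hunks above thread `config.model_parallel_size` into both the in-process and the process-group generator paths, replacing the hard-coded degree of 1, so multi-GPU checkpoints get the right process-group layout. A sketch of the initialization sequence, assuming the fairscale model-parallel helpers used by the Llama reference code (rank and world size come from the torchrun environment):

    import torch
    from fairscale.nn.model_parallel.initialize import (
        initialize_model_parallel,
        model_parallel_is_initialized,
    )

    def init_distributed(model_parallel_size: int) -> None:
        # Join the global NCCL process group, then partition it into
        # model-parallel groups of the requested size.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            initialize_model_parallel(model_parallel_size)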
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.14
+llama-models>=0.0.16
 pydantic
 requests
 termcolor
setup.py
@@ -16,7 +16,7 @@ def read_requirements():

 setup(
     name="llama_toolchain",
-    version="0.0.14",
+    version="0.0.16",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama toolchain",