diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index 233573ed4..d100cee61 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -16,12 +16,13 @@ import httpx
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-from llama_models.datatypes import CheckpointQuantizationFormat, ModelDefinition
+from llama_models.datatypes import Model
 from llama_models.sku_list import (
-    llama3_1_model_list,
-    llama_meta_folder_path,
-    llama_meta_pth_size,
+    all_registered_models,
+    llama_meta_net_info,
+    resolve_model,
 )
+from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand

 from llama_toolchain.utils import DEFAULT_DUMP_DIR
@@ -45,7 +46,7 @@ class Download(Subcommand):
         self.parser.set_defaults(func=self._run_download_cmd)

     def _add_arguments(self):
-        models = llama3_1_model_list()
+        models = all_registered_models()
         self.parser.add_argument(
             "--source",
             choices=["meta", "huggingface"],
@@ -53,7 +54,7 @@ class Download(Subcommand):
         )
         self.parser.add_argument(
             "--model-id",
-            choices=[x.sku.value for x in models],
+            choices=[x.descriptor() for x in models],
             required=True,
         )
         self.parser.add_argument(
@@ -80,12 +81,12 @@ safetensors files to avoid downloading duplicate weights.
 """,
         )

-    def _hf_download(self, model: ModelDefinition, hf_token: str, ignore_patterns: str):
-        repo_id = model.huggingface_id
+    def _hf_download(self, model: Model, hf_token: str, ignore_patterns: str):
+        repo_id = model.huggingface_repo
         if repo_id is None:
-            raise ValueError(f"No repo id found for model {model.sku.value}")
+            raise ValueError(f"No repo id found for model {model.descriptor()}")

-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value
+        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
         os.makedirs(output_dir, exist_ok=True)
         try:
             true_output_dir = snapshot_download(
@@ -111,43 +112,37 @@ safetensors files to avoid downloading duplicate weights.
print(f"Successfully downloaded model to {true_output_dir}") - def _meta_download(self, model: ModelDefinition, meta_url: str): - output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.sku.value + def _meta_download(self, model: Model, meta_url: str): + output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor() os.makedirs(output_dir, exist_ok=True) - gpus = model.hardware_requirements.gpu_count - files = [ - "tokenizer.model", - "params.json", - ] - if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: - files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)]) - files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)]) - - folder_path = llama_meta_folder_path(model) - pth_size = llama_meta_pth_size(model) + info = llama_meta_net_info(model) # I believe we can use some concurrency here if needed but not sure it is worth it - for f in files: + for f in info.files: output_file = str(output_dir / f) - url = meta_url.replace("*", f"{folder_path}/{f}") - total_size = pth_size if "consolidated" in f else 0 + url = meta_url.replace("*", f"{info.folder}/{f}") + total_size = info.pth_size if "consolidated" in f else 0 + cprint(f"Downloading `{f}`...", "white") downloader = ResumableDownloader(url, output_file, total_size) asyncio.run(downloader.download()) def _run_download_cmd(self, args: argparse.Namespace): - by_id = {model.sku.value: model for model in llama3_1_model_list()} - assert args.model_id in by_id, f"Unexpected model id {args.model_id}" + model = resolve_model(args.model_id) + if model is None: + self.parser.error(f"Model {args.model_id} not found") + return - model = by_id[args.model_id] if args.source == "huggingface": self._hf_download(model, args.hf_token, args.ignore_patterns) else: - if not args.meta_url: - self.parser.error( - "Please provide a meta url to download the model from llama.meta.com" + meta_url = args.meta_url + if not meta_url: + meta_url = input( + "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): " ) - self._meta_download(model, args.meta_url) + assert meta_url is not None and "llama3-1.llamameta.net" in meta_url + self._meta_download(model, meta_url) class ResumableDownloader: @@ -170,7 +165,10 @@ class ResumableDownloader: if self.total_size > 0: return - response = await client.head(self.url, follow_redirects=True) + # Force disable compression when trying to retrieve file size + response = await client.head( + self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"} + ) response.raise_for_status() self.url = str(response.url) # Update URL in case of redirects self.total_size = int(response.headers.get("Content-Length", 0)) diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index 6551e6e65..e38885814 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -9,7 +9,7 @@ import json from enum import Enum -from llama_models.sku_list import llama3_1_model_list +from llama_models.sku_list import resolve_model from termcolor import colored @@ -47,20 +47,13 @@ class ModelDescribe(Subcommand): ) def _run_model_describe_cmd(self, args: argparse.Namespace) -> None: - models = llama3_1_model_list() - by_id = {model.sku.value: model for model in models} - - if args.model_id not in by_id: - print( + model = resolve_model(args.model_id) + if model is None: + self.parser.error( f"Model {args.model_id} not found; try 'llama model list' for a list of available models." 
             )
             return

-        model = by_id[args.model_id]
-
-        sampling_params = model.recommended_sampling_params.dict()
-        for k in ("max_tokens", "repetition_penalty"):
-            del sampling_params[k]
         rows = [
             (
                 colored("Model", "white", attrs=["bold"]),
@@ -70,13 +63,20 @@ class ModelDescribe(Subcommand):
             ("Description", model.description_markdown),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
             ("Weights format", model.quantization_format.value),
-            (
-                "Recommended sampling params",
-                json.dumps(sampling_params, cls=EnumEncoder, indent=4),
-            ),
             ("Model params.json", json.dumps(model.model_args, indent=4)),
         ]

+        if model.recommended_sampling_params is not None:
+            sampling_params = model.recommended_sampling_params.dict()
+            for k in ("max_tokens", "repetition_penalty"):
+                del sampling_params[k]
+            rows.append(
+                (
+                    "Recommended sampling params",
+                    json.dumps(sampling_params, cls=EnumEncoder, indent=4),
+                )
+            )
+
         print_table(
             rows,
             separate_rows=True,
diff --git a/llama_toolchain/cli/model/list.py b/llama_toolchain/cli/model/list.py
index 9d26bb181..c6d4b24ac 100644
--- a/llama_toolchain/cli/model/list.py
+++ b/llama_toolchain/cli/model/list.py
@@ -6,7 +6,7 @@

 import argparse

-from llama_models.sku_list import llama3_1_model_list
+from llama_models.sku_list import all_registered_models

 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
@@ -30,21 +30,22 @@ class ModelList(Subcommand):
         pass

     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
-        models = llama3_1_model_list()
         headers = [
-            "Model ID",
-            "HuggingFace ID",
+            "Model Descriptor",
+            "HuggingFace Repo",
             "Context Length",
             "Hardware Requirements",
         ]

         rows = []
-        for model in models:
+        for model in all_registered_models():
             req = model.hardware_requirements
+
+            descriptor = model.descriptor()
             rows.append(
                 [
-                    model.sku.value,
-                    model.huggingface_id,
+                    descriptor,
+                    model.huggingface_repo,
                     f"{model.max_seq_length // 1024}K",
                     f"{req.gpu_count} GPU{'s' if req.gpu_count > 1 else ''}, each >= {req.memory_gb_per_gpu}GB VRAM",
                 ]
diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py
index 4ba9d420e..611fc9133 100644
--- a/llama_toolchain/inference/api/endpoints.py
+++ b/llama_toolchain/inference/api/endpoints.py
@@ -13,7 +13,7 @@ from pyopenapi import webmethod

 @json_schema_type
 class CompletionRequest(BaseModel):
-    model: PretrainedModel
+    model: str
     content: InterleavedTextAttachment
     sampling_params: Optional[SamplingParams] = SamplingParams()

@@ -39,7 +39,7 @@ class CompletionResponseStreamChunk(BaseModel):

 @json_schema_type
 class BatchCompletionRequest(BaseModel):
-    model: PretrainedModel
+    model: str
     content_batch: List[InterleavedTextAttachment]
     sampling_params: Optional[SamplingParams] = SamplingParams()
     logprobs: Optional[LogProbConfig] = None
@@ -53,7 +53,7 @@ class BatchCompletionResponse(BaseModel):

 @json_schema_type
 class ChatCompletionRequest(BaseModel):
-    model: InstructModel
+    model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()

@@ -80,7 +80,7 @@ class ChatCompletionResponse(BaseModel):

 @json_schema_type
 class BatchChatCompletionRequest(BaseModel):
-    model: InstructModel
+    model: str
     messages_batch: List[List[Message]]
     sampling_params: Optional[SamplingParams] = SamplingParams()
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 3dd646457..4e9dd5ee2 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -18,7 +18,6 @@ from .api import (
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
-    InstructModel,
     UserMessage,
 )
 from .event_logger import EventLogger
@@ -67,7 +66,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model="Meta-Llama-3.1-8B-Instruct",
             messages=[message],
             stream=stream,
         )
diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py
index 3641e69cf..d8309e94c 100644
--- a/llama_toolchain/post_training/api/endpoints.py
+++ b/llama_toolchain/post_training/api/endpoints.py
@@ -25,7 +25,7 @@ class PostTrainingSFTRequest(BaseModel):

     job_uuid: str

-    model: PretrainedModel
+    model: str
     dataset: TrainEvalDataset
     validation_dataset: TrainEvalDataset
diff --git a/llama_toolchain/reward_scoring/api/endpoints.py b/llama_toolchain/reward_scoring/api/endpoints.py
index 3f9144b32..375e859a2 100644
--- a/llama_toolchain/reward_scoring/api/endpoints.py
+++ b/llama_toolchain/reward_scoring/api/endpoints.py
@@ -15,7 +15,7 @@ class RewardScoringRequest(BaseModel):
     """Request to score a reward function. A list of prompts and a list of responses per prompt."""

     dialog_generations: List[DialogGenerations]
-    model: RewardModel
+    model: str


 @json_schema_type
diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py
index a2f54e9f0..a3b0c6ec6 100644
--- a/llama_toolchain/synthetic_data_generation/api/endpoints.py
+++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py
@@ -22,7 +22,7 @@ class SyntheticDataGenerationRequest(BaseModel):

     dialogs: List[Message]
     filtering_function: FilteringFunction = FilteringFunction.none
-    model: Optional[RewardModel] = None
+    model: Optional[str] = None


 @json_schema_type
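
Usage sketch (not part of the diff): a minimal example of how the reworked sku_list helpers are meant to be used together. It relies only on names visible in the hunks above (all_registered_models, resolve_model, Model.descriptor, Model.huggingface_repo) and on the descriptor string already used in client.py; treat it as an assumption-laden illustration, not an authoritative API reference.

from llama_models.sku_list import all_registered_models, resolve_model

# Enumerate every registered model, mirroring what `llama model list` prints.
for m in all_registered_models():
    print(m.descriptor(), m.huggingface_repo)

# Look up one model by its descriptor; request schemas now carry this plain
# string instead of a PretrainedModel/InstructModel enum value.
model = resolve_model("Meta-Llama-3.1-8B-Instruct")
if model is None:
    raise SystemExit("unknown model descriptor")
print(model.descriptor())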