diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000..27f039be6
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include llama_toolchain/data/*.yaml
diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py
index 33e3c658a..593f02c09 100644
--- a/llama_toolchain/cli/inference/configure.py
+++ b/llama_toolchain/cli/inference/configure.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import pkg_resources
 import textwrap
 
 from pathlib import Path
@@ -37,6 +38,7 @@ class InferenceConfigure(Subcommand):
     def read_user_inputs(self):
         checkpoint_dir = input("Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): ")
         model_parallel_size = input("Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): ")
+        assert model_parallel_size.isdigit() and int(model_parallel_size) in {1, 8}, "model parallel size must be 1 or 8"
 
         return checkpoint_dir, model_parallel_size
 
@@ -46,9 +48,10 @@ class InferenceConfigure(Subcommand):
         model_parallel_size,
         yaml_output_path
     ):
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        default_conf_path = os.path.join(current_dir, "default_configuration.yaml")
-
+        default_conf_path = pkg_resources.resource_filename(
+            'llama_toolchain',
+            'data/default_inference_config.yaml'
+        )
         with open(default_conf_path, "r") as f:
             yaml_content = f.read()
 
diff --git a/llama_toolchain/cli/inference/default_configuration.yaml b/llama_toolchain/data/default_inference_config.yaml
similarity index 100%
rename from llama_toolchain/cli/inference/default_configuration.yaml
rename to llama_toolchain/data/default_inference_config.yaml