From d95f5f863d437eaf2f5cae24aeeb1900392c8944 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sun, 21 Jul 2024 19:26:11 -0700 Subject: [PATCH] use default_config file to configure inference --- requirements.txt | 1 + toolchain/cli/inference/configure.py | 30 +++++++++---------- .../cli/inference/default_configuration.yaml | 9 ++++++ 3 files changed, 25 insertions(+), 15 deletions(-) create mode 100644 toolchain/cli/inference/default_configuration.yaml diff --git a/requirements.txt b/requirements.txt index 514c3ffa0..3ca6caf4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ accelerate black==24.4.2 +blobfile codeshield fairscale fastapi diff --git a/toolchain/cli/inference/configure.py b/toolchain/cli/inference/configure.py index 0c0ae61fe..c802bd8f8 100644 --- a/toolchain/cli/inference/configure.py +++ b/toolchain/cli/inference/configure.py @@ -41,22 +41,22 @@ class InferenceConfigure(Subcommand): return checkpoint_dir, model_parallel_size def write_output_yaml( - self, - checkpoint_dir, - model_parallel_size, + self, + checkpoint_dir, + model_parallel_size, yaml_output_path ): - yaml_content = textwrap.dedent(f""" - inference_config: - impl_type: "inline" - inline_config: - checkpoint_type: "pytorch" - checkpoint_dir: {checkpoint_dir}/ - tokenizer_path: {checkpoint_dir}/tokenizer.model - model_parallel_size: {model_parallel_size} - max_seq_len: 2048 - max_batch_size: 1 - """) + current_dir = os.path.dirname(os.path.abspath(__file__)) + default_conf_path = os.path.join(current_dir, "default_configuration.yaml") + + with open(default_conf_path, "r") as f: + yaml_content = f.read() + + yaml_content = yaml_content.format( + checkpoint_dir=checkpoint_dir, + model_parallel_size=model_parallel_size, + ) + with open(yaml_output_path, 'w') as yaml_file: yaml_file.write(yaml_content.strip()) @@ -65,7 +65,7 @@ class InferenceConfigure(Subcommand): def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None: checkpoint_dir, model_parallel_size = self.read_user_inputs() checkpoint_dir = os.path.expanduser(checkpoint_dir) - + if not ( checkpoint_dir.endswith("original") or checkpoint_dir.endswith("original/") diff --git a/toolchain/cli/inference/default_configuration.yaml b/toolchain/cli/inference/default_configuration.yaml new file mode 100644 index 000000000..253e0e143 --- /dev/null +++ b/toolchain/cli/inference/default_configuration.yaml @@ -0,0 +1,9 @@ +inference_config: + impl_type: "inline" + inline_config: + checkpoint_type: "pytorch" + checkpoint_dir: {checkpoint_dir}/ + tokenizer_path: {checkpoint_dir}/tokenizer.model + model_parallel_size: {model_parallel_size} + max_seq_len: 2048 + max_batch_size: 1