use default_config file to configure inference

Hardik Shah 2024-07-21 19:26:11 -07:00
parent c64b8cba22
commit d95f5f863d
3 changed files with 25 additions and 15 deletions


@@ -1,5 +1,6 @@
 accelerate
 black==24.4.2
+blobfile
 codeshield
 fairscale
 fastapi


@@ -41,22 +41,22 @@ class InferenceConfigure(Subcommand):
         return checkpoint_dir, model_parallel_size
 
     def write_output_yaml(
         self,
         checkpoint_dir,
         model_parallel_size,
         yaml_output_path
     ):
-        yaml_content = textwrap.dedent(f"""
-            inference_config:
-                impl_type: "inline"
-                inline_config:
-                    checkpoint_type: "pytorch"
-                    checkpoint_dir: {checkpoint_dir}/
-                    tokenizer_path: {checkpoint_dir}/tokenizer.model
-                    model_parallel_size: {model_parallel_size}
-                    max_seq_len: 2048
-                    max_batch_size: 1
-        """)
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        default_conf_path = os.path.join(current_dir, "default_configuration.yaml")
+
+        with open(default_conf_path, "r") as f:
+            yaml_content = f.read()
+
+        yaml_content = yaml_content.format(
+            checkpoint_dir=checkpoint_dir,
+            model_parallel_size=model_parallel_size,
+        )
         with open(yaml_output_path, 'w') as yaml_file:
             yaml_file.write(yaml_content.strip())
@@ -65,7 +65,7 @@ class InferenceConfigure(Subcommand):
     def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None:
         checkpoint_dir, model_parallel_size = self.read_user_inputs()
         checkpoint_dir = os.path.expanduser(checkpoint_dir)
 
         if not (
             checkpoint_dir.endswith("original") or
             checkpoint_dir.endswith("original/")
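
In short, this change swaps the hand-built textwrap.dedent f-string for a default_configuration.yaml template shipped next to the module and filled in with str.format. A minimal standalone sketch of that pattern, assuming the template file sits beside the script (the function name here is illustrative, not part of the repository):

import os

# Hypothetical standalone version of the template-fill pattern above;
# assumes "default_configuration.yaml" lives next to this script.
def render_default_config(checkpoint_dir: str, model_parallel_size: int) -> str:
    current_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(current_dir, "default_configuration.yaml")
    with open(template_path, "r") as f:
        template = f.read()
    # str.format fills the {checkpoint_dir} and {model_parallel_size} fields;
    # any literal braces in the template would need doubling ({{ and }}).
    return template.format(
        checkpoint_dir=checkpoint_dir,
        model_parallel_size=model_parallel_size,
    ).strip()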


@@ -0,0 +1,9 @@
+inference_config:
+    impl_type: "inline"
+    inline_config:
+        checkpoint_type: "pytorch"
+        checkpoint_dir: {checkpoint_dir}/
+        tokenizer_path: {checkpoint_dir}/tokenizer.model
+        model_parallel_size: {model_parallel_size}
+        max_seq_len: 2048
+        max_batch_size: 1
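
The bare {checkpoint_dir} and {model_parallel_size} tokens in this new file are plain str.format placeholders, filled in by write_output_yaml above. A tiny illustrative demo of that behavior (the path and size are made-up example values):

# Template fields behave like ordinary str.format named arguments.
template = "checkpoint_dir: {checkpoint_dir}/\nmodel_parallel_size: {model_parallel_size}"
print(template.format(checkpoint_dir="/home/user/checkpoints/original", model_parallel_size=2))
# Output:
# checkpoint_dir: /home/user/checkpoints/original/
# model_parallel_size: 2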