From 803976df26d1db5345b984904a49dcdc87ad1d03 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 3 Aug 2024 20:32:57 -0700 Subject: [PATCH] cleanup, moving stuff to common, nuke utils --- llama_toolchain/cli/distribution/configure.py | 226 +----------------- llama_toolchain/cli/distribution/install.py | 4 +- llama_toolchain/cli/distribution/start.py | 10 +- llama_toolchain/cli/download.py | 5 +- llama_toolchain/cli/inference/configure.py | 91 ------- llama_toolchain/cli/inference/inference.py | 34 --- llama_toolchain/cli/model/describe.py | 2 +- llama_toolchain/common/config_dirs.py | 15 ++ .../distribution/utils.py => common/exec.py} | 2 + llama_toolchain/common/prompt_for_config.py | 224 +++++++++++++++++ .../__init__.py => common/serialize.py} | 10 +- llama_toolchain/inference/inference.py | 2 +- llama_toolchain/utils.py | 34 --- 13 files changed, 263 insertions(+), 396 deletions(-) delete mode 100644 llama_toolchain/cli/inference/configure.py delete mode 100644 llama_toolchain/cli/inference/inference.py create mode 100644 llama_toolchain/common/config_dirs.py rename llama_toolchain/{cli/distribution/utils.py => common/exec.py} (96%) create mode 100644 llama_toolchain/common/prompt_for_config.py rename llama_toolchain/{cli/inference/__init__.py => common/serialize.py} (50%) delete mode 100644 llama_toolchain/utils.py diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index a5e2aaa65..bb87d0ddf 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -5,22 +5,16 @@ # the root directory of this source tree. import argparse -import importlib -import inspect import json import shlex -from enum import Enum from pathlib import Path -from typing import get_args, get_origin, List, Literal, Optional, Union import yaml -from pydantic import BaseModel from termcolor import cprint -from typing_extensions import Annotated from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder +from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR class DistributionConfigure(Subcommand): @@ -66,9 +60,11 @@ class DistributionConfigure(Subcommand): def configure_llama_distribution(dist: "Distribution", conda_env: str): + from llama_toolchain.common.exec import run_command + from llama_toolchain.common.prompt_for_config import prompt_for_config + from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.distribution.datatypes import PassthroughApiAdapter - - from .utils import run_command + from llama_toolchain.distribution.dynamic import instantiate_class_type python_exe = run_command(shlex.split("which python")) # simple check @@ -121,215 +117,3 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str): fp.write(yaml.dump(dist_config, sort_keys=False)) print(f"YAML configuration has been written to {config_path}") - - -def instantiate_class_type(fully_qualified_name): - module_name, class_name = fully_qualified_name.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) - - -def is_list_of_primitives(field_type): - """Check if a field type is a List of primitive types.""" - origin = get_origin(field_type) - if origin is List or origin is list: - args = get_args(field_type) - if len(args) == 1 and args[0] in (int, float, str, bool): - return True - return False - - -def get_literal_values(field): - """Extract literal values from a field if it's a Literal type.""" - if get_origin(field.annotation) is Literal: - return get_args(field.annotation) - return None - - -def is_optional(field_type): - """Check if a field type is Optional.""" - return get_origin(field_type) is Union and type(None) in get_args(field_type) - - -def get_non_none_type(field_type): - """Get the non-None type from an Optional type.""" - return next(arg for arg in get_args(field_type) if arg is not type(None)) - - -# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones, -# it should prompt iteratively if the user wants to add more values) -def prompt_for_config( - config_type: type[BaseModel], existing_config: Optional[BaseModel] = None -) -> BaseModel: - """ - Recursively prompt the user for configuration values based on a Pydantic BaseModel. - - Args: - config_type: A Pydantic BaseModel class representing the configuration structure. - - Returns: - An instance of the config_type with user-provided values. - """ - config_data = {} - - for field_name, field in config_type.__fields__.items(): - field_type = field.annotation - - existing_value = ( - getattr(existing_config, field_name) if existing_config else None - ) - if existing_value: - default_value = existing_value - else: - default_value = ( - field.default if not isinstance(field.default, type(Ellipsis)) else None - ) - is_required = field.required - - # Skip fields with Literal type - if get_origin(field_type) is Literal: - continue - - if inspect.isclass(field_type) and issubclass(field_type, Enum): - prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):" - while True: - # this branch does not handle existing and default values yet - user_input = input(prompt + " ") - try: - config_data[field_name] = field_type[user_input] - break - except KeyError: - print( - f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}" - ) - continue - - # Check if the field is a discriminated union - if get_origin(field_type) is Annotated: - inner_type = get_args(field_type)[0] - if get_origin(inner_type) is Union: - discriminator = field.field_info.discriminator - if discriminator: - union_types = get_args(inner_type) - # Find the discriminator field in each union type - type_map = {} - for t in union_types: - disc_field = t.__fields__[discriminator] - literal_values = get_literal_values(disc_field) - if literal_values: - for value in literal_values: - type_map[value] = t - - while True: - discriminator_value = input( - f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): " - ) - if discriminator_value in type_map: - chosen_type = type_map[discriminator_value] - print(f"\nConfiguring {chosen_type.__name__}:") - - if existing_value and ( - getattr(existing_value, discriminator) - != discriminator_value - ): - existing_value = None - - sub_config = prompt_for_config(chosen_type, existing_value) - config_data[field_name] = sub_config - # Set the discriminator field in the sub-config - setattr(sub_config, discriminator, discriminator_value) - break - else: - print(f"Invalid {discriminator}. Please try again.") - continue - - if ( - is_optional(field_type) - and inspect.isclass(get_non_none_type(field_type)) - and issubclass(get_non_none_type(field_type), BaseModel) - ): - prompt = f"Do you want to configure {field_name}? (y/n): " - if input(prompt).lower() != "y": - config_data[field_name] = None - continue - nested_type = get_non_none_type(field_type) - print(f"Entering sub-configuration for {field_name}:") - config_data[field_name] = prompt_for_config(nested_type, existing_value) - elif inspect.isclass(field_type) and issubclass(field_type, BaseModel): - print(f"\nEntering sub-configuration for {field_name}:") - config_data[field_name] = prompt_for_config( - field_type, - existing_value, - ) - else: - prompt = f"Enter value for {field_name}" - if existing_value is not None: - prompt += f" (existing: {existing_value})" - elif default_value is not None: - prompt += f" (default: {default_value})" - if is_optional(field_type): - prompt += " (optional)" - elif is_required: - prompt += " (required)" - prompt += ": " - - while True: - user_input = input(prompt) - if user_input == "": - if default_value is not None: - config_data[field_name] = default_value - break - elif is_optional(field_type) or not is_required: - config_data[field_name] = None - break - else: - print("This field is required. Please provide a value.") - continue - - try: - # Handle Optional types - if is_optional(field_type): - if user_input.lower() == "none": - config_data[field_name] = None - break - field_type = get_non_none_type(field_type) - - # Handle List of primitives - if is_list_of_primitives(field_type): - try: - value = json.loads(user_input) - if not isinstance(value, list): - raise ValueError("Input must be a JSON-encoded list") - element_type = get_args(field_type)[0] - config_data[field_name] = [ - element_type(item) for item in value - ] - break - except json.JSONDecodeError: - print( - "Invalid JSON. Please enter a valid JSON-encoded list." - ) - continue - except ValueError as e: - print(f"{str(e)}") - continue - - # Convert the input to the correct type - if inspect.isclass(field_type) and issubclass( - field_type, BaseModel - ): - # For nested BaseModels, we assume a dictionary-like string input - import ast - - config_data[field_name] = field_type( - **ast.literal_eval(user_input) - ) - else: - config_data[field_name] = field_type(user_input) - break - except ValueError: - print( - f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}" - ) - - return config_type(**config_data) diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index 8584e7517..1fde93bca 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -11,7 +11,7 @@ import shlex import pkg_resources from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.utils import DISTRIBS_BASE_DIR +from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR class DistributionInstall(Subcommand): @@ -46,9 +46,9 @@ class DistributionInstall(Subcommand): ) def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.common.exec import run_command, run_with_pty from llama_toolchain.distribution.distribution import distribution_dependencies from llama_toolchain.distribution.registry import resolve_distribution - from .utils import run_command, run_with_pty os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True) script = pkg_resources.resource_filename( diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py index 04caeca51..a1dbd9438 100644 --- a/llama_toolchain/cli/distribution/start.py +++ b/llama_toolchain/cli/distribution/start.py @@ -11,7 +11,7 @@ from pathlib import Path import yaml from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.utils import DISTRIBS_BASE_DIR +from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR class DistributionStart(Subcommand): @@ -48,9 +48,9 @@ class DistributionStart(Subcommand): ) def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.common.exec import run_command from llama_toolchain.distribution.registry import resolve_distribution from llama_toolchain.distribution.server import main as distribution_server_init - from .utils import run_command dist = resolve_distribution(args.name) if dist is None: @@ -67,6 +67,7 @@ class DistributionStart(Subcommand): config = yaml.safe_load(fp) conda_env = config["conda_env"] + python_exe = run_command(shlex.split("which python")) # simple check, unfortunate if conda_env not in python_exe: @@ -80,8 +81,3 @@ class DistributionStart(Subcommand): args.port, disable_ipv6=args.disable_ipv6, ) - # run_with_pty( - # shlex.split( - # f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000" - # ) - # ) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index 892af927a..1fa420f4b 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -16,10 +16,7 @@ import httpx from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR - - -DEFAULT_CHECKPOINT_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "checkpoints") +from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR class Download(Subcommand): diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py deleted file mode 100644 index 1a511ea62..000000000 --- a/llama_toolchain/cli/inference/configure.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import argparse -import os -import textwrap - -from pathlib import Path - -import pkg_resources - -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR - - -CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs") - - -class InferenceConfigure(Subcommand): - """Llama cli for configuring llama toolchain configs""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "configure", - prog="llama inference configure", - description="Configure llama toolchain inference configs", - epilog=textwrap.dedent( - """ - Example: - llama inference configure - """ - ), - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_inference_configure_cmd) - - def _add_arguments(self): - pass - - def read_user_inputs(self): - checkpoint_dir = input( - "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): " - ) - model_parallel_size = input( - "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): " - ) - assert model_parallel_size.isdigit() and int(model_parallel_size) in { - 1, - 8, - }, "model parallel size must be 1 or 8" - - return checkpoint_dir, model_parallel_size - - def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path): - default_conf_path = pkg_resources.resource_filename( - "llama_toolchain", "data/default_inference_config.yaml" - ) - with open(default_conf_path, "r") as f: - yaml_content = f.read() - - yaml_content = yaml_content.format( - checkpoint_dir=checkpoint_dir, - model_parallel_size=model_parallel_size, - ) - - with open(yaml_output_path, "w") as yaml_file: - yaml_file.write(yaml_content.strip()) - - print(f"YAML configuration has been written to {yaml_output_path}") - - def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None: - checkpoint_dir, model_parallel_size = self.read_user_inputs() - checkpoint_dir = os.path.expanduser(checkpoint_dir) - - assert ( - Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir() - ), f"{checkpoint_dir} does not exist or it not a directory" - - os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) - yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml" - - self.write_output_yaml( - checkpoint_dir, - model_parallel_size, - yaml_output_path, - ) diff --git a/llama_toolchain/cli/inference/inference.py b/llama_toolchain/cli/inference/inference.py deleted file mode 100644 index 51a82b1f0..000000000 --- a/llama_toolchain/cli/inference/inference.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import argparse -import textwrap - -from llama_toolchain.cli.inference.configure import InferenceConfigure -from llama_toolchain.cli.subcommand import Subcommand - - -class InferenceParser(Subcommand): - """Llama cli for inference apis""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "inference", - prog="llama inference", - description="Run inference on a llama model", - epilog=textwrap.dedent( - """ - Example: - llama inference start - """ - ), - ) - - subparsers = self.parser.add_subparsers(title="inference_subcommands") - - # Add sub-commands - InferenceConfigure.create(subparsers) diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index a9d02de78..a24fe15f7 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -13,7 +13,7 @@ from termcolor import colored from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.table import print_table -from llama_toolchain.utils import EnumEncoder +from llama_toolchain.common.serialize import EnumEncoder class ModelDescribe(Subcommand): diff --git a/llama_toolchain/common/config_dirs.py b/llama_toolchain/common/config_dirs.py new file mode 100644 index 000000000..8be45b047 --- /dev/null +++ b/llama_toolchain/common/config_dirs.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from pathlib import Path + + +LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/") + +DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions" + +DEFAULT_CHECKPOINT_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "checkpoints" diff --git a/llama_toolchain/cli/distribution/utils.py b/llama_toolchain/common/exec.py similarity index 96% rename from llama_toolchain/cli/distribution/utils.py rename to llama_toolchain/common/exec.py index 91547ea83..a01a1cf80 100644 --- a/llama_toolchain/cli/distribution/utils.py +++ b/llama_toolchain/common/exec.py @@ -16,6 +16,8 @@ import termios from termcolor import cprint +# run a command in a pseudo-terminal, with interrupt handling, +# useful when you want to run interactive things def run_with_pty(command): master, slave = pty.openpty() diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py new file mode 100644 index 000000000..c708b96d7 --- /dev/null +++ b/llama_toolchain/common/prompt_for_config.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import inspect +import json +from enum import Enum + +from typing import get_args, get_origin, List, Literal, Optional, Union + +from pydantic import BaseModel + +from typing_extensions import Annotated + + +def is_list_of_primitives(field_type): + """Check if a field type is a List of primitive types.""" + origin = get_origin(field_type) + if origin is List or origin is list: + args = get_args(field_type) + if len(args) == 1 and args[0] in (int, float, str, bool): + return True + return False + + +def get_literal_values(field): + """Extract literal values from a field if it's a Literal type.""" + if get_origin(field.annotation) is Literal: + return get_args(field.annotation) + return None + + +def is_optional(field_type): + """Check if a field type is Optional.""" + return get_origin(field_type) is Union and type(None) in get_args(field_type) + + +def get_non_none_type(field_type): + """Get the non-None type from an Optional type.""" + return next(arg for arg in get_args(field_type) if arg is not type(None)) + + +# This is somewhat elaborate, but does not purport to be comprehensive in any way. +# We should add handling for the most common cases to tide us over. +# +# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of +# unit tests for coverage. +def prompt_for_config( + config_type: type[BaseModel], existing_config: Optional[BaseModel] = None +) -> BaseModel: + """ + Recursively prompt the user for configuration values based on a Pydantic BaseModel. + + Args: + config_type: A Pydantic BaseModel class representing the configuration structure. + + Returns: + An instance of the config_type with user-provided values. + """ + config_data = {} + + for field_name, field in config_type.__fields__.items(): + field_type = field.annotation + + existing_value = ( + getattr(existing_config, field_name) if existing_config else None + ) + if existing_value: + default_value = existing_value + else: + default_value = ( + field.default if not isinstance(field.default, type(Ellipsis)) else None + ) + is_required = field.required + + # Skip fields with Literal type + if get_origin(field_type) is Literal: + continue + + if inspect.isclass(field_type) and issubclass(field_type, Enum): + prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):" + while True: + # this branch does not handle existing and default values yet + user_input = input(prompt + " ") + try: + config_data[field_name] = field_type[user_input] + break + except KeyError: + print( + f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}" + ) + continue + + # Check if the field is a discriminated union + if get_origin(field_type) is Annotated: + inner_type = get_args(field_type)[0] + if get_origin(inner_type) is Union: + discriminator = field.field_info.discriminator + if discriminator: + union_types = get_args(inner_type) + # Find the discriminator field in each union type + type_map = {} + for t in union_types: + disc_field = t.__fields__[discriminator] + literal_values = get_literal_values(disc_field) + if literal_values: + for value in literal_values: + type_map[value] = t + + while True: + discriminator_value = input( + f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): " + ) + if discriminator_value in type_map: + chosen_type = type_map[discriminator_value] + print(f"\nConfiguring {chosen_type.__name__}:") + + if existing_value and ( + getattr(existing_value, discriminator) + != discriminator_value + ): + existing_value = None + + sub_config = prompt_for_config(chosen_type, existing_value) + config_data[field_name] = sub_config + # Set the discriminator field in the sub-config + setattr(sub_config, discriminator, discriminator_value) + break + else: + print(f"Invalid {discriminator}. Please try again.") + continue + + if ( + is_optional(field_type) + and inspect.isclass(get_non_none_type(field_type)) + and issubclass(get_non_none_type(field_type), BaseModel) + ): + prompt = f"Do you want to configure {field_name}? (y/n): " + if input(prompt).lower() != "y": + config_data[field_name] = None + continue + nested_type = get_non_none_type(field_type) + print(f"Entering sub-configuration for {field_name}:") + config_data[field_name] = prompt_for_config(nested_type, existing_value) + elif inspect.isclass(field_type) and issubclass(field_type, BaseModel): + print(f"\nEntering sub-configuration for {field_name}:") + config_data[field_name] = prompt_for_config( + field_type, + existing_value, + ) + else: + prompt = f"Enter value for {field_name}" + if existing_value is not None: + prompt += f" (existing: {existing_value})" + elif default_value is not None: + prompt += f" (default: {default_value})" + if is_optional(field_type): + prompt += " (optional)" + elif is_required: + prompt += " (required)" + prompt += ": " + + while True: + user_input = input(prompt) + if user_input == "": + if default_value is not None: + config_data[field_name] = default_value + break + elif is_optional(field_type) or not is_required: + config_data[field_name] = None + break + else: + print("This field is required. Please provide a value.") + continue + + try: + # Handle Optional types + if is_optional(field_type): + if user_input.lower() == "none": + config_data[field_name] = None + break + field_type = get_non_none_type(field_type) + + # Handle List of primitives + if is_list_of_primitives(field_type): + try: + value = json.loads(user_input) + if not isinstance(value, list): + raise ValueError("Input must be a JSON-encoded list") + element_type = get_args(field_type)[0] + config_data[field_name] = [ + element_type(item) for item in value + ] + break + except json.JSONDecodeError: + print( + "Invalid JSON. Please enter a valid JSON-encoded list." + ) + continue + except ValueError as e: + print(f"{str(e)}") + continue + + # Convert the input to the correct type + if inspect.isclass(field_type) and issubclass( + field_type, BaseModel + ): + # For nested BaseModels, we assume a dictionary-like string input + import ast + + config_data[field_name] = field_type( + **ast.literal_eval(user_input) + ) + else: + config_data[field_name] = field_type(user_input) + break + except ValueError: + print( + f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}" + ) + + return config_type(**config_data) diff --git a/llama_toolchain/cli/inference/__init__.py b/llama_toolchain/common/serialize.py similarity index 50% rename from llama_toolchain/cli/inference/__init__.py rename to llama_toolchain/common/serialize.py index 74f5fc120..813851fe9 100644 --- a/llama_toolchain/cli/inference/__init__.py +++ b/llama_toolchain/common/serialize.py @@ -4,4 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .inference import InferenceParser # noqa +import json +from enum import Enum + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + return super().default(obj) diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index bdcbe971e..beeb6dd65 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -54,7 +54,7 @@ class MetaReferenceInferenceImpl(Inference): async def initialize(self) -> None: self.generator = LlamaModelParallelGenerator(self.config) - # self.generator.start() + self.generator.start() async def shutdown(self) -> None: self.generator.stop() diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py deleted file mode 100644 index 19d6fe976..000000000 --- a/llama_toolchain/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import os -from enum import Enum -from pathlib import Path - - -LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/") - -DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions" - - -def get_root_directory(): - current_dir = os.path.dirname(os.path.abspath(__file__)) - while os.path.isfile(os.path.join(current_dir, "__init__.py")): - current_dir = os.path.dirname(current_dir) - - return current_dir - - -def get_default_config_dir(): - return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs") - - -class EnumEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, Enum): - return obj.value - return super().default(obj)