cleanup, moving stuff to common, nuke utils

This commit is contained in:
Ashwin Bharambe 2024-08-03 20:32:57 -07:00
parent fe582a739d
commit 803976df26
13 changed files with 263 additions and 396 deletions

View file

@ -5,22 +5,16 @@
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
import importlib
import inspect
import json import json
import shlex import shlex
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import get_args, get_origin, List, Literal, Optional, Union
import yaml import yaml
from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class DistributionConfigure(Subcommand): class DistributionConfigure(Subcommand):
@ -66,9 +60,11 @@ class DistributionConfigure(Subcommand):
def configure_llama_distribution(dist: "Distribution", conda_env: str): def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from llama_toolchain.distribution.dynamic import instantiate_class_type
from .utils import run_command
python_exe = run_command(shlex.split("which python")) python_exe = run_command(shlex.split("which python"))
# simple check # simple check
@ -121,215 +117,3 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
fp.write(yaml.dump(dist_config, sort_keys=False)) fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}") print(f"YAML configuration has been written to {config_path}")
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is List or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
return False
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones,
# it should prompt iteratively if the user wants to add more values)
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
if existing_value:
default_value = existing_value
else:
default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None
)
is_required = field.required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
# this branch does not handle existing and default values yet
user_input = input(prompt + " ")
try:
config_data[field_name] = field_type[user_input]
break
except KeyError:
print(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
# Check if the field is a discriminated union
if get_origin(field_type) is Annotated:
inner_type = get_args(field_type)[0]
if get_origin(inner_type) is Union:
discriminator = field.field_info.discriminator
if discriminator:
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
discriminator_value = input(
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator)
!= discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
config_data[field_name] = sub_config
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
break
else:
print(f"Invalid {discriminator}. Please try again.")
continue
if (
is_optional(field_type)
and inspect.isclass(get_non_none_type(field_type))
and issubclass(get_non_none_type(field_type), BaseModel)
):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() != "y":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else:
prompt = f"Enter value for {field_name}"
if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
continue
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
config_data[field_name] = None
break
field_type = get_non_none_type(field_type)
# Handle List of primitives
if is_list_of_primitives(field_type):
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
config_data[field_name] = [
element_type(item) for item in value
]
break
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded list."
)
continue
except ValueError as e:
print(f"{str(e)}")
continue
# Convert the input to the correct type
if inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
# For nested BaseModels, we assume a dictionary-like string input
import ast
config_data[field_name] = field_type(
**ast.literal_eval(user_input)
)
else:
config_data[field_name] = field_type(user_input)
break
except ValueError:
print(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
return config_type(**config_data)

View file

@ -11,7 +11,7 @@ import shlex
import pkg_resources import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class DistributionInstall(Subcommand): class DistributionInstall(Subcommand):
@ -46,9 +46,9 @@ class DistributionInstall(Subcommand):
) )
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_command, run_with_pty
from llama_toolchain.distribution.distribution import distribution_dependencies from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution from llama_toolchain.distribution.registry import resolve_distribution
from .utils import run_command, run_with_pty
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True) os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(

View file

@ -11,7 +11,7 @@ from pathlib import Path
import yaml import yaml
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class DistributionStart(Subcommand): class DistributionStart(Subcommand):
@ -48,9 +48,9 @@ class DistributionStart(Subcommand):
) )
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_command
from llama_toolchain.distribution.registry import resolve_distribution from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.server import main as distribution_server_init from llama_toolchain.distribution.server import main as distribution_server_init
from .utils import run_command
dist = resolve_distribution(args.name) dist = resolve_distribution(args.name)
if dist is None: if dist is None:
@ -67,6 +67,7 @@ class DistributionStart(Subcommand):
config = yaml.safe_load(fp) config = yaml.safe_load(fp)
conda_env = config["conda_env"] conda_env = config["conda_env"]
python_exe = run_command(shlex.split("which python")) python_exe = run_command(shlex.split("which python"))
# simple check, unfortunate # simple check, unfortunate
if conda_env not in python_exe: if conda_env not in python_exe:
@ -80,8 +81,3 @@ class DistributionStart(Subcommand):
args.port, args.port,
disable_ipv6=args.disable_ipv6, disable_ipv6=args.disable_ipv6,
) )
# run_with_pty(
# shlex.split(
# f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
# )
# )

View file

@ -16,10 +16,7 @@ import httpx
from termcolor import cprint from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR
DEFAULT_CHECKPOINT_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "checkpoints")
class Download(Subcommand): class Download(Subcommand):

View file

@ -1,91 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
class InferenceConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama inference configure",
description="Configure llama toolchain inference configs",
epilog=textwrap.dedent(
"""
Example:
llama inference configure
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_inference_configure_cmd)
def _add_arguments(self):
pass
def read_user_inputs(self):
checkpoint_dir = input(
"Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
)
model_parallel_size = input(
"Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
)
assert model_parallel_size.isdigit() and int(model_parallel_size) in {
1,
8,
}, "model parallel size must be 1 or 8"
return checkpoint_dir, model_parallel_size
def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
default_conf_path = pkg_resources.resource_filename(
"llama_toolchain", "data/default_inference_config.yaml"
)
with open(default_conf_path, "r") as f:
yaml_content = f.read()
yaml_content = yaml_content.format(
checkpoint_dir=checkpoint_dir,
model_parallel_size=model_parallel_size,
)
with open(yaml_output_path, "w") as yaml_file:
yaml_file.write(yaml_content.strip())
print(f"YAML configuration has been written to {yaml_output_path}")
def _run_inference_configure_cmd(self, args: argparse.Namespace) -> None:
checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir)
assert (
Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
), f"{checkpoint_dir} does not exist or it not a directory"
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"
self.write_output_yaml(
checkpoint_dir,
model_parallel_size,
yaml_output_path,
)

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import textwrap
from llama_toolchain.cli.inference.configure import InferenceConfigure
from llama_toolchain.cli.subcommand import Subcommand
class InferenceParser(Subcommand):
"""Llama cli for inference apis"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"inference",
prog="llama inference",
description="Run inference on a llama model",
epilog=textwrap.dedent(
"""
Example:
llama inference start <options>
"""
),
)
subparsers = self.parser.add_subparsers(title="inference_subcommands")
# Add sub-commands
InferenceConfigure.create(subparsers)

View file

@ -13,7 +13,7 @@ from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table from llama_toolchain.cli.table import print_table
from llama_toolchain.utils import EnumEncoder from llama_toolchain.common.serialize import EnumEncoder
class ModelDescribe(Subcommand): class ModelDescribe(Subcommand):

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
DEFAULT_CHECKPOINT_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "checkpoints"

View file

@ -16,6 +16,8 @@ import termios
from termcolor import cprint from termcolor import cprint
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def run_with_pty(command): def run_with_pty(command):
master, slave = pty.openpty() master, slave = pty.openpty()

View file

@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
from enum import Enum
from typing import get_args, get_origin, List, Literal, Optional, Union
from pydantic import BaseModel
from typing_extensions import Annotated
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is List or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
return False
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
return get_args(field.annotation)
return None
def is_optional(field_type):
"""Check if a field type is Optional."""
return get_origin(field_type) is Union and type(None) in get_args(field_type)
def get_non_none_type(field_type):
"""Get the non-None type from an Optional type."""
return next(arg for arg in get_args(field_type) if arg is not type(None))
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
Args:
config_type: A Pydantic BaseModel class representing the configuration structure.
Returns:
An instance of the config_type with user-provided values.
"""
config_data = {}
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
if existing_value:
default_value = existing_value
else:
default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None
)
is_required = field.required
# Skip fields with Literal type
if get_origin(field_type) is Literal:
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
# this branch does not handle existing and default values yet
user_input = input(prompt + " ")
try:
config_data[field_name] = field_type[user_input]
break
except KeyError:
print(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
# Check if the field is a discriminated union
if get_origin(field_type) is Annotated:
inner_type = get_args(field_type)[0]
if get_origin(inner_type) is Union:
discriminator = field.field_info.discriminator
if discriminator:
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
discriminator_value = input(
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator)
!= discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
config_data[field_name] = sub_config
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
break
else:
print(f"Invalid {discriminator}. Please try again.")
continue
if (
is_optional(field_type)
and inspect.isclass(get_non_none_type(field_type))
and issubclass(get_non_none_type(field_type), BaseModel)
):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() != "y":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else:
prompt = f"Enter value for {field_name}"
if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})"
if is_optional(field_type):
prompt += " (optional)"
elif is_required:
prompt += " (required)"
prompt += ": "
while True:
user_input = input(prompt)
if user_input == "":
if default_value is not None:
config_data[field_name] = default_value
break
elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
continue
try:
# Handle Optional types
if is_optional(field_type):
if user_input.lower() == "none":
config_data[field_name] = None
break
field_type = get_non_none_type(field_type)
# Handle List of primitives
if is_list_of_primitives(field_type):
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
config_data[field_name] = [
element_type(item) for item in value
]
break
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded list."
)
continue
except ValueError as e:
print(f"{str(e)}")
continue
# Convert the input to the correct type
if inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
# For nested BaseModels, we assume a dictionary-like string input
import ast
config_data[field_name] = field_type(
**ast.literal_eval(user_input)
)
else:
config_data[field_name] = field_type(user_input)
break
except ValueError:
print(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
return config_type(**config_data)

View file

@ -4,4 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .inference import InferenceParser # noqa import json
from enum import Enum
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
return super().default(obj)

View file

@ -54,7 +54,7 @@ class MetaReferenceInferenceImpl(Inference):
async def initialize(self) -> None: async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config) self.generator = LlamaModelParallelGenerator(self.config)
# self.generator.start() self.generator.start()
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.generator.stop() self.generator.stop()

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from enum import Enum
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
def get_root_directory():
current_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.isfile(os.path.join(current_dir, "__init__.py")):
current_dir = os.path.dirname(current_dir)
return current_dir
def get_default_config_dir():
return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
return super().default(obj)