diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index 10a6baf3c..88666f67e 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -7,10 +7,11 @@ import argparse import importlib import inspect +import json import shlex from pathlib import Path -from typing import Annotated, get_args, get_origin, Literal, Union +from typing import Annotated, get_args, get_origin, Literal, Optional, Union import yaml from pydantic import BaseModel @@ -22,7 +23,7 @@ from llama_toolchain.distribution.registry import ( available_distributions, resolve_distribution, ) -from llama_toolchain.utils import DISTRIBS_BASE_DIR +from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder from .utils import run_command @@ -73,6 +74,17 @@ def configure_llama_distribution(dist: Distribution, conda_env: str): f"Please re-run configure by activating the `{conda_env}` conda environment" ) + existing_config = None + config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" + if config_path.exists(): + cprint( + f"Configuration already exists for {dist.name}. Will overwrite...", + "yellow", + attrs=["bold"], + ) + with open(config_path, "r") as fp: + existing_config = yaml.safe_load(fp) + adapter_configs = {} for api_surface, adapter in dist.adapters.items(): if isinstance(adapter, PassthroughApiAdapter): @@ -82,8 +94,14 @@ def configure_llama_distribution(dist: Distribution, conda_env: str): f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"] ) config_type = instantiate_class_type(adapter.config_class) - # TODO: when we are re-configuring, we should read existing values - config = prompt_for_config(config_type) + config = prompt_for_config( + config_type, + ( + config_type(**existing_config["adapters"][api_surface.value]) + if existing_config + else None + ), + ) adapter_configs[api_surface.value] = config.dict() dist_config = { @@ -91,11 +109,11 @@ def configure_llama_distribution(dist: Distribution, conda_env: str): "conda_env": conda_env, } - yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" - with open(yaml_output_path, "w") as fp: + with open(config_path, "w") as fp: + dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder)) fp.write(yaml.dump(dist_config, sort_keys=False)) - print(f"YAML configuration has been written to {yaml_output_path}") + print(f"YAML configuration has been written to {config_path}") def instantiate_class_type(fully_qualified_name): @@ -121,7 +139,9 @@ def get_non_none_type(field_type): return next(arg for arg in get_args(field_type) if arg is not type(None)) -def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: +def prompt_for_config( + config_type: type[BaseModel], existing_config: Optional[BaseModel] = None +) -> BaseModel: """ Recursively prompt the user for configuration values based on a Pydantic BaseModel. @@ -135,9 +155,16 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: for field_name, field in config_type.__fields__.items(): field_type = field.annotation - default_value = ( - field.default if not isinstance(field.default, type(Ellipsis)) else None + + existing_value = ( + getattr(existing_config, field_name) if existing_config else None ) + if existing_value: + default_value = existing_value + else: + default_value = ( + field.default if not isinstance(field.default, type(Ellipsis)) else None + ) is_required = field.required # Skip fields with Literal type @@ -167,7 +194,14 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: if discriminator_value in type_map: chosen_type = type_map[discriminator_value] print(f"\nConfiguring {chosen_type.__name__}:") - sub_config = prompt_for_config(chosen_type) + + if existing_value and ( + getattr(existing_value, discriminator) + != discriminator_value + ): + existing_value = None + + sub_config = prompt_for_config(chosen_type, existing_value) config_data[field_name] = sub_config # Set the discriminator field in the sub-config setattr(sub_config, discriminator, discriminator_value) @@ -178,10 +212,15 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: if inspect.isclass(field_type) and issubclass(field_type, BaseModel): print(f"\nEntering sub-configuration for {field_name}:") - config_data[field_name] = prompt_for_config(field_type) + config_data[field_name] = prompt_for_config( + field_type, + existing_value, + ) else: prompt = f"Enter value for {field_name}" - if default_value is not None: + if existing_value is not None: + prompt += f" (existing: {existing_value})" + elif default_value is not None: prompt += f" (default: {default_value})" if is_optional(field_type): prompt += " (optional)" @@ -195,10 +234,7 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: if default_value is not None: config_data[field_name] = default_value break - elif is_optional(field_type): - config_data[field_name] = None - break - elif not is_required: + elif is_optional(field_type) or not is_required: config_data[field_name] = None break else: diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index e38885814..a9d02de78 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -7,21 +7,13 @@ import argparse import json -from enum import Enum - from llama_models.sku_list import resolve_model from termcolor import colored from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.table import print_table - - -class EnumEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, Enum): - return obj.value - return super().default(obj) +from llama_toolchain.utils import EnumEncoder class ModelDescribe(Subcommand): diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py index d0805d901..0b4df3b30 100644 --- a/llama_toolchain/utils.py +++ b/llama_toolchain/utils.py @@ -5,7 +5,9 @@ # the root directory of this source tree. import getpass +import json import os +from enum import Enum from pathlib import Path from typing import Optional @@ -65,3 +67,10 @@ def parse_config(config_dir: str, config_path: Optional[str] = None) -> str: print("------------------------") return config + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + return super().default(obj)