read existing configuration, save enums properly

Ashwin Bharambe 2024-08-02 13:55:29 -07:00
parent 2cf9915806
commit 3bc827cd5f
3 changed files with 63 additions and 26 deletions
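
Note: the enum-saving half of this change hinges on round-tripping the config dict through json with the relocated EnumEncoder before handing it to yaml.dump, so that only plain strings reach the YAML dumper. A minimal sketch of that path (the Quantization enum and the adapter payload are made up for illustration; only the EnumEncoder round-trip mirrors the diff):

    import json
    from enum import Enum

    import yaml

    from llama_toolchain.utils import EnumEncoder  # added to utils.py by this commit


    class Quantization(Enum):  # hypothetical enum standing in for real config values
        fp8 = "fp8"


    dist_config = {
        "adapters": {
            "inference": {"url": "http://localhost:5000", "format": Quantization.fp8},
        },
        "conda_env": "local",
    }

    # Enum members collapse to their .value strings before YAML ever sees them,
    # exactly as configure_llama_distribution does in the first file below.
    dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
    print(yaml.dump(dist_config, sort_keys=False))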

@@ -7,10 +7,11 @@
 import argparse
 import importlib
 import inspect
+import json
 import shlex
 from pathlib import Path
-from typing import Annotated, get_args, get_origin, Literal, Union
+from typing import Annotated, get_args, get_origin, Literal, Optional, Union
 import yaml
 from pydantic import BaseModel
@@ -22,7 +23,7 @@ from llama_toolchain.distribution.registry import (
     available_distributions,
     resolve_distribution,
 )
-from llama_toolchain.utils import DISTRIBS_BASE_DIR
+from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
 from .utils import run_command
@@ -73,6 +74,17 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
             f"Please re-run configure by activating the `{conda_env}` conda environment"
         )
+    existing_config = None
+    config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
+    if config_path.exists():
+        cprint(
+            f"Configuration already exists for {dist.name}. Will overwrite...",
+            "yellow",
+            attrs=["bold"],
+        )
+        with open(config_path, "r") as fp:
+            existing_config = yaml.safe_load(fp)
 adapter_configs = {}
     for api_surface, adapter in dist.adapters.items():
         if isinstance(adapter, PassthroughApiAdapter):
@@ -82,8 +94,14 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
                 f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
             )
             config_type = instantiate_class_type(adapter.config_class)
-            # TODO: when we are re-configuring, we should read existing values
-            config = prompt_for_config(config_type)
+            config = prompt_for_config(
+                config_type,
+                (
+                    config_type(**existing_config["adapters"][api_surface.value])
+                    if existing_config
+                    else None
+                ),
+            )
             adapter_configs[api_surface.value] = config.dict()
     dist_config = {
@@ -91,11 +109,11 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
         "conda_env": conda_env,
     }
     yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
-    with open(yaml_output_path, "w") as fp:
+    with open(config_path, "w") as fp:
         dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
         fp.write(yaml.dump(dist_config, sort_keys=False))
-    print(f"YAML configuration has been written to {yaml_output_path}")
+    print(f"YAML configuration has been written to {config_path}")
 def instantiate_class_type(fully_qualified_name):
@@ -121,7 +139,9 @@ def get_non_none_type(field_type):
     return next(arg for arg in get_args(field_type) if arg is not type(None))
-def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
+def prompt_for_config(
+    config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
+) -> BaseModel:
     """
     Recursively prompt the user for configuration values based on a Pydantic BaseModel.
@@ -135,9 +155,16 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
     for field_name, field in config_type.__fields__.items():
         field_type = field.annotation
-        default_value = (
-            field.default if not isinstance(field.default, type(Ellipsis)) else None
+        existing_value = (
+            getattr(existing_config, field_name) if existing_config else None
         )
+        if existing_value:
+            default_value = existing_value
+        else:
+            default_value = (
+                field.default if not isinstance(field.default, type(Ellipsis)) else None
+            )
         is_required = field.required
         # Skip fields with Literal type
@@ -167,7 +194,14 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
                         if discriminator_value in type_map:
                             chosen_type = type_map[discriminator_value]
                             print(f"\nConfiguring {chosen_type.__name__}:")
-                            sub_config = prompt_for_config(chosen_type)
+                            if existing_value and (
+                                getattr(existing_value, discriminator)
+                                != discriminator_value
+                            ):
+                                existing_value = None
+                            sub_config = prompt_for_config(chosen_type, existing_value)
                             config_data[field_name] = sub_config
                             # Set the discriminator field in the sub-config
                             setattr(sub_config, discriminator, discriminator_value)
@@ -178,10 +212,15 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
         if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
             print(f"\nEntering sub-configuration for {field_name}:")
-            config_data[field_name] = prompt_for_config(field_type)
+            config_data[field_name] = prompt_for_config(
+                field_type,
+                existing_value,
+            )
         else:
             prompt = f"Enter value for {field_name}"
-            if default_value is not None:
+            if existing_value is not None:
+                prompt += f" (existing: {existing_value})"
+            elif default_value is not None:
                 prompt += f" (default: {default_value})"
             if is_optional(field_type):
                 prompt += " (optional)"
@@ -195,10 +234,7 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
                     if default_value is not None:
                         config_data[field_name] = default_value
                         break
-                    elif is_optional(field_type):
-                        config_data[field_name] = None
-                        break
-                    elif not is_required:
+                    elif is_optional(field_type) or not is_required:
                         config_data[field_name] = None
                         break
                     else:
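
Taken together, the first file now supports a read-modify-write cycle: an existing config.yaml is loaded, each adapter section is re-hydrated into its pydantic config class, and prompt_for_config offers those values back as "(existing: ...)" defaults. A rough sketch, assuming it runs inside this module so that prompt_for_config and DISTRIBS_BASE_DIR are in scope (OllamaImplConfig, the local-ollama directory, and the inference key are illustrative stand-ins, not names taken from this diff):

    from pathlib import Path

    import yaml
    from pydantic import BaseModel


    class OllamaImplConfig(BaseModel):  # hypothetical adapter config class
        url: str = "http://localhost:11434"
        max_tokens: int = 4096


    config_path = Path(DISTRIBS_BASE_DIR) / "local-ollama" / "config.yaml"
    with open(config_path, "r") as fp:
        existing_config = yaml.safe_load(fp)

    # Re-hydrate the previously saved adapter section and let the prompts show
    # those values instead of the class defaults.
    existing = OllamaImplConfig(**existing_config["adapters"]["inference"])
    config = prompt_for_config(OllamaImplConfig, existing)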

@@ -7,21 +7,13 @@
 import argparse
 import json
-from enum import Enum
 from llama_models.sku_list import resolve_model
 from termcolor import colored
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
-class EnumEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, Enum):
-            return obj.value
-        return super().default(obj)
+from llama_toolchain.utils import EnumEncoder
 class ModelDescribe(Subcommand):

@@ -5,7 +5,9 @@
 # the root directory of this source tree.
 import getpass
+import json
 import os
+from enum import Enum
 from pathlib import Path
 from typing import Optional
@@ -65,3 +67,10 @@ def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
     print("------------------------")
     return config
+class EnumEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, Enum):
+            return obj.value
+        return super().default(obj)
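
For reference, the relocated encoder behaves as below: plain json.dumps raises TypeError on an Enum member, while EnumEncoder flattens it to its .value (a minimal check with a made-up Format enum):

    import json
    from enum import Enum

    from llama_toolchain.utils import EnumEncoder


    class Format(Enum):  # hypothetical
        fp8 = "fp8"


    print(json.dumps({"quantization_format": Format.fp8}, cls=EnumEncoder))
    # -> {"quantization_format": "fp8"}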