read existing configuration, save enums properly

This commit is contained in:
Ashwin Bharambe 2024-08-02 13:55:29 -07:00
parent 2cf9915806
commit 3bc827cd5f
3 changed files with 63 additions and 26 deletions

View file

@ -7,10 +7,11 @@
import argparse import argparse
import importlib import importlib
import inspect import inspect
import json
import shlex import shlex
from pathlib import Path from pathlib import Path
from typing import Annotated, get_args, get_origin, Literal, Union from typing import Annotated, get_args, get_origin, Literal, Optional, Union
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
@ -22,7 +23,7 @@ from llama_toolchain.distribution.registry import (
available_distributions, available_distributions,
resolve_distribution, resolve_distribution,
) )
from llama_toolchain.utils import DISTRIBS_BASE_DIR from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
from .utils import run_command from .utils import run_command
@ -73,6 +74,17 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
f"Please re-run configure by activating the `{conda_env}` conda environment" f"Please re-run configure by activating the `{conda_env}` conda environment"
) )
existing_config = None
config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
if config_path.exists():
cprint(
f"Configuration already exists for {dist.name}. Will overwrite...",
"yellow",
attrs=["bold"],
)
with open(config_path, "r") as fp:
existing_config = yaml.safe_load(fp)
adapter_configs = {} adapter_configs = {}
for api_surface, adapter in dist.adapters.items(): for api_surface, adapter in dist.adapters.items():
if isinstance(adapter, PassthroughApiAdapter): if isinstance(adapter, PassthroughApiAdapter):
@ -82,8 +94,14 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"] f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
) )
config_type = instantiate_class_type(adapter.config_class) config_type = instantiate_class_type(adapter.config_class)
# TODO: when we are re-configuring, we should read existing values config = prompt_for_config(
config = prompt_for_config(config_type) config_type,
(
config_type(**existing_config["adapters"][api_surface.value])
if existing_config
else None
),
)
adapter_configs[api_surface.value] = config.dict() adapter_configs[api_surface.value] = config.dict()
dist_config = { dist_config = {
@ -91,11 +109,11 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
"conda_env": conda_env, "conda_env": conda_env,
} }
yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" with open(config_path, "w") as fp:
with open(yaml_output_path, "w") as fp: dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False)) fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {yaml_output_path}") print(f"YAML configuration has been written to {config_path}")
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
@ -121,7 +139,9 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None)) return next(arg for arg in get_args(field_type) if arg is not type(None))
def prompt_for_config(config_type: type[BaseModel]) -> BaseModel: def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
""" """
Recursively prompt the user for configuration values based on a Pydantic BaseModel. Recursively prompt the user for configuration values based on a Pydantic BaseModel.
@ -135,9 +155,16 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
for field_name, field in config_type.__fields__.items(): for field_name, field in config_type.__fields__.items():
field_type = field.annotation field_type = field.annotation
default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None existing_value = (
getattr(existing_config, field_name) if existing_config else None
) )
if existing_value:
default_value = existing_value
else:
default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None
)
is_required = field.required is_required = field.required
# Skip fields with Literal type # Skip fields with Literal type
@ -167,7 +194,14 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
if discriminator_value in type_map: if discriminator_value in type_map:
chosen_type = type_map[discriminator_value] chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:") print(f"\nConfiguring {chosen_type.__name__}:")
sub_config = prompt_for_config(chosen_type)
if existing_value and (
getattr(existing_value, discriminator)
!= discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
config_data[field_name] = sub_config config_data[field_name] = sub_config
# Set the discriminator field in the sub-config # Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value) setattr(sub_config, discriminator, discriminator_value)
@ -178,10 +212,15 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
if inspect.isclass(field_type) and issubclass(field_type, BaseModel): if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
print(f"\nEntering sub-configuration for {field_name}:") print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(field_type) config_data[field_name] = prompt_for_config(
field_type,
existing_value,
)
else: else:
prompt = f"Enter value for {field_name}" prompt = f"Enter value for {field_name}"
if default_value is not None: if existing_value is not None:
prompt += f" (existing: {existing_value})"
elif default_value is not None:
prompt += f" (default: {default_value})" prompt += f" (default: {default_value})"
if is_optional(field_type): if is_optional(field_type):
prompt += " (optional)" prompt += " (optional)"
@ -195,10 +234,7 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
if default_value is not None: if default_value is not None:
config_data[field_name] = default_value config_data[field_name] = default_value
break break
elif is_optional(field_type): elif is_optional(field_type) or not is_required:
config_data[field_name] = None
break
elif not is_required:
config_data[field_name] = None config_data[field_name] = None
break break
else: else:

View file

@ -7,21 +7,13 @@
import argparse import argparse
import json import json
from enum import Enum
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from termcolor import colored from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table from llama_toolchain.cli.table import print_table
from llama_toolchain.utils import EnumEncoder
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
return super().default(obj)
class ModelDescribe(Subcommand): class ModelDescribe(Subcommand):

View file

@ -5,7 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import getpass import getpass
import json
import os import os
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -65,3 +67,10 @@ def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
print("------------------------") print("------------------------")
return config return config
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
return obj.value
return super().default(obj)