mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
read existing configuration, save enums properly
This commit is contained in:
parent
2cf9915806
commit
3bc827cd5f
3 changed files with 63 additions and 26 deletions
|
@ -7,10 +7,11 @@
|
|||
import argparse
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import shlex
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, get_args, get_origin, Literal, Union
|
||||
from typing import Annotated, get_args, get_origin, Literal, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
@ -22,7 +23,7 @@ from llama_toolchain.distribution.registry import (
|
|||
available_distributions,
|
||||
resolve_distribution,
|
||||
)
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
|
||||
|
||||
from .utils import run_command
|
||||
|
||||
|
@ -73,6 +74,17 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
|
|||
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
||||
)
|
||||
|
||||
existing_config = None
|
||||
config_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
|
||||
if config_path.exists():
|
||||
cprint(
|
||||
f"Configuration already exists for {dist.name}. Will overwrite...",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
with open(config_path, "r") as fp:
|
||||
existing_config = yaml.safe_load(fp)
|
||||
|
||||
adapter_configs = {}
|
||||
for api_surface, adapter in dist.adapters.items():
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
|
@ -82,8 +94,14 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
|
|||
f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
|
||||
)
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
# TODO: when we are re-configuring, we should read existing values
|
||||
config = prompt_for_config(config_type)
|
||||
config = prompt_for_config(
|
||||
config_type,
|
||||
(
|
||||
config_type(**existing_config["adapters"][api_surface.value])
|
||||
if existing_config
|
||||
else None
|
||||
),
|
||||
)
|
||||
adapter_configs[api_surface.value] = config.dict()
|
||||
|
||||
dist_config = {
|
||||
|
@ -91,11 +109,11 @@ def configure_llama_distribution(dist: Distribution, conda_env: str):
|
|||
"conda_env": conda_env,
|
||||
}
|
||||
|
||||
yaml_output_path = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
|
||||
with open(yaml_output_path, "w") as fp:
|
||||
with open(config_path, "w") as fp:
|
||||
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
|
||||
fp.write(yaml.dump(dist_config, sort_keys=False))
|
||||
|
||||
print(f"YAML configuration has been written to {yaml_output_path}")
|
||||
print(f"YAML configuration has been written to {config_path}")
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -121,7 +139,9 @@ def get_non_none_type(field_type):
|
|||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
|
||||
def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
|
||||
def prompt_for_config(
|
||||
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
||||
|
||||
|
@ -135,9 +155,16 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
|
|||
|
||||
for field_name, field in config_type.__fields__.items():
|
||||
field_type = field.annotation
|
||||
default_value = (
|
||||
field.default if not isinstance(field.default, type(Ellipsis)) else None
|
||||
|
||||
existing_value = (
|
||||
getattr(existing_config, field_name) if existing_config else None
|
||||
)
|
||||
if existing_value:
|
||||
default_value = existing_value
|
||||
else:
|
||||
default_value = (
|
||||
field.default if not isinstance(field.default, type(Ellipsis)) else None
|
||||
)
|
||||
is_required = field.required
|
||||
|
||||
# Skip fields with Literal type
|
||||
|
@ -167,7 +194,14 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
|
|||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
sub_config = prompt_for_config(chosen_type)
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator)
|
||||
!= discriminator_value
|
||||
):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
config_data[field_name] = sub_config
|
||||
# Set the discriminator field in the sub-config
|
||||
setattr(sub_config, discriminator, discriminator_value)
|
||||
|
@ -178,10 +212,15 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
|
|||
|
||||
if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
print(f"\nEntering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(field_type)
|
||||
config_data[field_name] = prompt_for_config(
|
||||
field_type,
|
||||
existing_value,
|
||||
)
|
||||
else:
|
||||
prompt = f"Enter value for {field_name}"
|
||||
if default_value is not None:
|
||||
if existing_value is not None:
|
||||
prompt += f" (existing: {existing_value})"
|
||||
elif default_value is not None:
|
||||
prompt += f" (default: {default_value})"
|
||||
if is_optional(field_type):
|
||||
prompt += " (optional)"
|
||||
|
@ -195,10 +234,7 @@ def prompt_for_config(config_type: type[BaseModel]) -> BaseModel:
|
|||
if default_value is not None:
|
||||
config_data[field_name] = default_value
|
||||
break
|
||||
elif is_optional(field_type):
|
||||
config_data[field_name] = None
|
||||
break
|
||||
elif not is_required:
|
||||
elif is_optional(field_type) or not is_required:
|
||||
config_data[field_name] = None
|
||||
break
|
||||
else:
|
||||
|
|
|
@ -7,21 +7,13 @@
|
|||
import argparse
|
||||
import json
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.cli.table import print_table
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
return super().default(obj)
|
||||
from llama_toolchain.utils import EnumEncoder
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
@ -65,3 +67,10 @@ def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
|
|||
print("------------------------")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Enum):
|
||||
return obj.value
|
||||
return super().default(obj)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue