add DistributionConfig, fix a bug in model download

This commit is contained in:
Ashwin Bharambe 2024-08-06 19:24:52 -07:00
parent ade574a0ef
commit 9e1ca4eeb1
4 changed files with 35 additions and 39 deletions

View file

@ -8,8 +8,6 @@ import argparse
import json import json
import shlex import shlex
from typing import Any, Dict
import yaml import yaml
from termcolor import cprint from termcolor import cprint
@ -40,6 +38,7 @@ class DistributionConfigure(Subcommand):
) )
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml" config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
@ -51,16 +50,16 @@ class DistributionConfigure(Subcommand):
# we need to find the spec from the name # we need to find the spec from the name
with open(config_file, "r") as f: with open(config_file, "r") as f:
config = yaml.safe_load(f) config = DistributionConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"]) dist = resolve_distribution_spec(config.spec)
if dist is None: if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`") raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config) configure_llama_distribution(dist, config)
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]): def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
from llama_toolchain.common.exec import run_command from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder from llama_toolchain.common.serialize import EnumEncoder
@ -68,47 +67,39 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
python_exe = run_command(shlex.split("which python")) python_exe = run_command(shlex.split("which python"))
# simple check # simple check
conda_env = config["conda_env"] conda_env = config.conda_env
if conda_env not in python_exe: if conda_env not in python_exe:
raise ValueError( raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment" f"Please re-run configure by activating the `{conda_env}` conda environment"
) )
existing_config = config if config.providers:
if "providers" in existing_config:
cprint( cprint(
f"Configuration already exists for {config['name']}. Will overwrite...", f"Configuration already exists for {config.name}. Will overwrite...",
"yellow", "yellow",
attrs=["bold"], attrs=["bold"],
) )
provider_configs = {}
for api, provider_spec in dist.provider_specs.items(): for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config( provider_config = prompt_for_config(
config_type, config_type,
( (
config_type(**existing_config["providers"][api.value]) config_type(**config.providers[api.value])
if existing_config if api.value in config.providers
and "providers" in existing_config
and api.value in existing_config["providers"]
else None else None
), ),
) )
provider_configs[api.value] = {
config.providers[api.value] = {
"provider_id": provider_spec.provider_id, "provider_id": provider_spec.provider_id,
**config.dict(), **provider_config.dict(),
} }
dist_config = { config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
**existing_config,
"providers": provider_configs,
}
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
with open(config_path, "w") as fp: with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder)) dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False)) fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}") print(f"YAML configuration has been written to {config_path}")

View file

@ -6,7 +6,6 @@
import argparse import argparse
import os import os
import shlex
import textwrap import textwrap
import pkg_resources import pkg_resources
@ -56,6 +55,7 @@ class DistributionInstall(Subcommand):
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec from llama_toolchain.distribution.registry import resolve_distribution_spec
@ -86,12 +86,12 @@ class DistributionInstall(Subcommand):
config_file = distrib_dir / "config.yaml" config_file = distrib_dir / "config.yaml"
with open(config_file, "w") as f: with open(config_file, "w") as f:
c = { c = DistributionConfig(
"conda_env": conda_env, spec=dist.spec_id,
"spec": dist.spec_id, name=args.name,
"name": args.name, conda_env=conda_env,
} )
f.write(yaml.dump(c)) f.write(yaml.dump(c.dict(), sort_keys=False))
cprint( cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!", f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
@ -105,4 +105,5 @@ class DistributionInstall(Subcommand):
conda deactivate && conda activate {conda_env} conda deactivate && conda activate {conda_env}
llama distribution configure --name {args.name} llama distribution configure --name {args.name}
""" """
)) )
)

View file

@ -16,7 +16,6 @@ import httpx
from termcolor import cprint from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR
class Download(Subcommand): class Download(Subcommand):
@ -109,9 +108,10 @@ safetensors files to avoid downloading duplicate weights.
def _meta_download(self, model: "Model", meta_url: str): def _meta_download(self, model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.common.model_utils import model_local_dir
output_dir = model_local_dir(model) output_dir = Path(model_local_dir(model))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
info = llama_meta_net_info(model) info = llama_meta_net_info(model)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type from strong_typing.schema import json_schema_type
@ -91,9 +91,13 @@ class DistributionSpec(BaseModel):
@json_schema_type @json_schema_type
class InstalledDistribution(BaseModel): class DistributionConfig(BaseModel):
"""References to a installed / configured DistributionSpec""" """References to a installed / configured DistributionSpec"""
name: str name: str
spec_id: str spec: str
# This is the class which represents the configs written by `configure` conda_env: str
providers: Dict[str, Any] = Field(
default_factory=dict,
description="Provider configurations for each of the APIs provided by this distribution",
)