From 9e1ca4eeb1497b1133f8ef4c98c01faa31e65f63 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 6 Aug 2024 19:24:52 -0700 Subject: [PATCH] add DistributionConfig, fix a bug in model download --- llama_toolchain/cli/distribution/configure.py | 41 ++++++++----------- llama_toolchain/cli/distribution/install.py | 17 ++++---- llama_toolchain/cli/download.py | 4 +- llama_toolchain/distribution/datatypes.py | 12 ++++-- 4 files changed, 35 insertions(+), 39 deletions(-) diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index 235d48da1..e90c875c5 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -8,8 +8,6 @@ import argparse import json import shlex -from typing import Any, Dict - import yaml from termcolor import cprint @@ -40,6 +38,7 @@ class DistributionConfigure(Subcommand): ) def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.distribution.datatypes import DistributionConfig from llama_toolchain.distribution.registry import resolve_distribution_spec config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml" @@ -51,16 +50,16 @@ class DistributionConfigure(Subcommand): # we need to find the spec from the name with open(config_file, "r") as f: - config = yaml.safe_load(f) + config = DistributionConfig(**yaml.safe_load(f)) - dist = resolve_distribution_spec(config["spec"]) + dist = resolve_distribution_spec(config.spec) if dist is None: - raise ValueError(f"Could not find any registered spec `{config['spec']}`") + raise ValueError(f"Could not find any registered spec `{config.spec}`") configure_llama_distribution(dist, config) -def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]): +def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"): from llama_toolchain.common.exec import run_command from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_toolchain.common.serialize import EnumEncoder @@ -68,47 +67,39 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]): python_exe = run_command(shlex.split("which python")) # simple check - conda_env = config["conda_env"] + conda_env = config.conda_env if conda_env not in python_exe: raise ValueError( f"Please re-run configure by activating the `{conda_env}` conda environment" ) - existing_config = config - if "providers" in existing_config: + if config.providers: cprint( - f"Configuration already exists for {config['name']}. Will overwrite...", + f"Configuration already exists for {config.name}. Will overwrite...", "yellow", attrs=["bold"], ) - provider_configs = {} for api, provider_spec in dist.provider_specs.items(): cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) config_type = instantiate_class_type(provider_spec.config_class) - config = prompt_for_config( + provider_config = prompt_for_config( config_type, ( - config_type(**existing_config["providers"][api.value]) - if existing_config - and "providers" in existing_config - and api.value in existing_config["providers"] + config_type(**config.providers[api.value]) + if api.value in config.providers else None ), ) - provider_configs[api.value] = { + + config.providers[api.value] = { "provider_id": provider_spec.provider_id, - **config.dict(), + **provider_config.dict(), } - dist_config = { - **existing_config, - "providers": provider_configs, - } - - config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml" + config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml" with open(config_path, "w") as fp: - dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder)) + dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder)) fp.write(yaml.dump(dist_config, sort_keys=False)) print(f"YAML configuration has been written to {config_path}") diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index e30e05268..8ce3c04b9 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -6,7 +6,6 @@ import argparse import os -import shlex import textwrap import pkg_resources @@ -56,6 +55,7 @@ class DistributionInstall(Subcommand): def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty + from llama_toolchain.distribution.datatypes import DistributionConfig from llama_toolchain.distribution.distribution import distribution_dependencies from llama_toolchain.distribution.registry import resolve_distribution_spec @@ -86,12 +86,12 @@ class DistributionInstall(Subcommand): config_file = distrib_dir / "config.yaml" with open(config_file, "w") as f: - c = { - "conda_env": conda_env, - "spec": dist.spec_id, - "name": args.name, - } - f.write(yaml.dump(c)) + c = DistributionConfig( + spec=dist.spec_id, + name=args.name, + conda_env=conda_env, + ) + f.write(yaml.dump(c.dict(), sort_keys=False)) cprint( f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!", @@ -105,4 +105,5 @@ class DistributionInstall(Subcommand): conda deactivate && conda activate {conda_env} llama distribution configure --name {args.name} """ - )) + ) + ) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index b268d3b8d..b8ade9b14 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -16,7 +16,6 @@ import httpx from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR class Download(Subcommand): @@ -109,9 +108,10 @@ safetensors files to avoid downloading duplicate weights. def _meta_download(self, model: "Model", meta_url: str): from llama_models.sku_list import llama_meta_net_info + from llama_toolchain.common.model_utils import model_local_dir - output_dir = model_local_dir(model) + output_dir = Path(model_local_dir(model)) os.makedirs(output_dir, exist_ok=True) info = llama_meta_net_info(model) diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index 00aa07682..b5c0d8e1f 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type @@ -91,9 +91,13 @@ class DistributionSpec(BaseModel): @json_schema_type -class InstalledDistribution(BaseModel): +class DistributionConfig(BaseModel): """References to a installed / configured DistributionSpec""" name: str - spec_id: str - # This is the class which represents the configs written by `configure` + spec: str + conda_env: str + providers: Dict[str, Any] = Field( + default_factory=dict, + description="Provider configurations for each of the APIs provided by this distribution", + )