add DistributionConfig, fix a bug in model download

This commit is contained in:
Ashwin Bharambe 2024-08-06 19:24:52 -07:00
parent ade574a0ef
commit 9e1ca4eeb1
4 changed files with 35 additions and 39 deletions

View file

@ -8,8 +8,6 @@ import argparse
import json
import shlex
from typing import Any, Dict
import yaml
from termcolor import cprint
@ -40,6 +38,7 @@ class DistributionConfigure(Subcommand):
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
@ -51,16 +50,16 @@ class DistributionConfigure(Subcommand):
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
config = DistributionConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"])
dist = resolve_distribution_spec(config.spec)
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config)
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
@ -68,47 +67,39 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config["conda_env"]
conda_env = config.conda_env
if conda_env not in python_exe:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
)
existing_config = config
if "providers" in existing_config:
if config.providers:
cprint(
f"Configuration already exists for {config['name']}. Will overwrite...",
f"Configuration already exists for {config.name}. Will overwrite...",
"yellow",
attrs=["bold"],
)
provider_configs = {}
for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config(
provider_config = prompt_for_config(
config_type,
(
config_type(**existing_config["providers"][api.value])
if existing_config
and "providers" in existing_config
and api.value in existing_config["providers"]
config_type(**config.providers[api.value])
if api.value in config.providers
else None
),
)
provider_configs[api.value] = {
config.providers[api.value] = {
"provider_id": provider_spec.provider_id,
**config.dict(),
**provider_config.dict(),
}
dist_config = {
**existing_config,
"providers": provider_configs,
}
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")

View file

@ -6,7 +6,6 @@
import argparse
import os
import shlex
import textwrap
import pkg_resources
@ -56,6 +55,7 @@ class DistributionInstall(Subcommand):
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
@ -86,12 +86,12 @@ class DistributionInstall(Subcommand):
config_file = distrib_dir / "config.yaml"
with open(config_file, "w") as f:
c = {
"conda_env": conda_env,
"spec": dist.spec_id,
"name": args.name,
}
f.write(yaml.dump(c))
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
@ -105,4 +105,5 @@ class DistributionInstall(Subcommand):
conda deactivate && conda activate {conda_env}
llama distribution configure --name {args.name}
"""
))
)
)

View file

@ -16,7 +16,6 @@ import httpx
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR
class Download(Subcommand):
@ -109,9 +108,10 @@ safetensors files to avoid downloading duplicate weights.
def _meta_download(self, model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir
output_dir = model_local_dir(model)
output_dir = Path(model_local_dir(model))
os.makedirs(output_dir, exist_ok=True)
info = llama_meta_net_info(model)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
@ -91,9 +91,13 @@ class DistributionSpec(BaseModel):
@json_schema_type
class InstalledDistribution(BaseModel):
class DistributionConfig(BaseModel):
"""References to a installed / configured DistributionSpec"""
name: str
spec_id: str
# This is the class which represents the configs written by `configure`
spec: str
conda_env: str
providers: Dict[str, Any] = Field(
default_factory=dict,
description="Provider configurations for each of the APIs provided by this distribution",
)