mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add DistributionConfig, fix a bug in model download
This commit is contained in:
parent
ade574a0ef
commit
9e1ca4eeb1
4 changed files with 35 additions and 39 deletions
|
@ -8,8 +8,6 @@ import argparse
|
|||
import json
|
||||
import shlex
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -40,6 +38,7 @@ class DistributionConfigure(Subcommand):
|
|||
)
|
||||
|
||||
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_toolchain.distribution.datatypes import DistributionConfig
|
||||
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||
|
||||
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
|
||||
|
@ -51,16 +50,16 @@ class DistributionConfigure(Subcommand):
|
|||
|
||||
# we need to find the spec from the name
|
||||
with open(config_file, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
config = DistributionConfig(**yaml.safe_load(f))
|
||||
|
||||
dist = resolve_distribution_spec(config["spec"])
|
||||
dist = resolve_distribution_spec(config.spec)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
|
||||
raise ValueError(f"Could not find any registered spec `{config.spec}`")
|
||||
|
||||
configure_llama_distribution(dist, config)
|
||||
|
||||
|
||||
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
|
||||
def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
|
||||
from llama_toolchain.common.exec import run_command
|
||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||
from llama_toolchain.common.serialize import EnumEncoder
|
||||
|
@ -68,47 +67,39 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
|
|||
|
||||
python_exe = run_command(shlex.split("which python"))
|
||||
# simple check
|
||||
conda_env = config["conda_env"]
|
||||
conda_env = config.conda_env
|
||||
if conda_env not in python_exe:
|
||||
raise ValueError(
|
||||
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
||||
)
|
||||
|
||||
existing_config = config
|
||||
if "providers" in existing_config:
|
||||
if config.providers:
|
||||
cprint(
|
||||
f"Configuration already exists for {config['name']}. Will overwrite...",
|
||||
f"Configuration already exists for {config.name}. Will overwrite...",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
|
||||
provider_configs = {}
|
||||
for api, provider_spec in dist.provider_specs.items():
|
||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = prompt_for_config(
|
||||
provider_config = prompt_for_config(
|
||||
config_type,
|
||||
(
|
||||
config_type(**existing_config["providers"][api.value])
|
||||
if existing_config
|
||||
and "providers" in existing_config
|
||||
and api.value in existing_config["providers"]
|
||||
config_type(**config.providers[api.value])
|
||||
if api.value in config.providers
|
||||
else None
|
||||
),
|
||||
)
|
||||
provider_configs[api.value] = {
|
||||
|
||||
config.providers[api.value] = {
|
||||
"provider_id": provider_spec.provider_id,
|
||||
**config.dict(),
|
||||
**provider_config.dict(),
|
||||
}
|
||||
|
||||
dist_config = {
|
||||
**existing_config,
|
||||
"providers": provider_configs,
|
||||
}
|
||||
|
||||
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
|
||||
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
|
||||
with open(config_path, "w") as fp:
|
||||
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
|
||||
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
|
||||
fp.write(yaml.dump(dist_config, sort_keys=False))
|
||||
|
||||
print(f"YAML configuration has been written to {config_path}")
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import textwrap
|
||||
|
||||
import pkg_resources
|
||||
|
@ -56,6 +55,7 @@ class DistributionInstall(Subcommand):
|
|||
|
||||
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
||||
from llama_toolchain.common.exec import run_with_pty
|
||||
from llama_toolchain.distribution.datatypes import DistributionConfig
|
||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||
|
||||
|
@ -86,12 +86,12 @@ class DistributionInstall(Subcommand):
|
|||
|
||||
config_file = distrib_dir / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
c = {
|
||||
"conda_env": conda_env,
|
||||
"spec": dist.spec_id,
|
||||
"name": args.name,
|
||||
}
|
||||
f.write(yaml.dump(c))
|
||||
c = DistributionConfig(
|
||||
spec=dist.spec_id,
|
||||
name=args.name,
|
||||
conda_env=conda_env,
|
||||
)
|
||||
f.write(yaml.dump(c.dict(), sort_keys=False))
|
||||
|
||||
cprint(
|
||||
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
|
||||
|
@ -105,4 +105,5 @@ class DistributionInstall(Subcommand):
|
|||
conda deactivate && conda activate {conda_env}
|
||||
llama distribution configure --name {args.name}
|
||||
"""
|
||||
))
|
||||
)
|
||||
)
|
||||
|
|
|
@ -16,7 +16,6 @@ import httpx
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
|
@ -109,9 +108,10 @@ safetensors files to avoid downloading duplicate weights.
|
|||
|
||||
def _meta_download(self, model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = model_local_dir(model)
|
||||
output_dir = Path(model_local_dir(model))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from strong_typing.schema import json_schema_type
|
||||
|
@ -91,9 +91,13 @@ class DistributionSpec(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class InstalledDistribution(BaseModel):
|
||||
class DistributionConfig(BaseModel):
|
||||
"""References to a installed / configured DistributionSpec"""
|
||||
|
||||
name: str
|
||||
spec_id: str
|
||||
# This is the class which represents the configs written by `configure`
|
||||
spec: str
|
||||
conda_env: str
|
||||
providers: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Provider configurations for each of the APIs provided by this distribution",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue