mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add DistributionConfig, fix a bug in model download
This commit is contained in:
parent
ade574a0ef
commit
9e1ca4eeb1
4 changed files with 35 additions and 39 deletions
|
@ -8,8 +8,6 @@ import argparse
|
||||||
import json
|
import json
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
@ -40,6 +38,7 @@ class DistributionConfigure(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.distribution.datatypes import DistributionConfig
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
|
|
||||||
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
|
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
|
||||||
|
@ -51,16 +50,16 @@ class DistributionConfigure(Subcommand):
|
||||||
|
|
||||||
# we need to find the spec from the name
|
# we need to find the spec from the name
|
||||||
with open(config_file, "r") as f:
|
with open(config_file, "r") as f:
|
||||||
config = yaml.safe_load(f)
|
config = DistributionConfig(**yaml.safe_load(f))
|
||||||
|
|
||||||
dist = resolve_distribution_spec(config["spec"])
|
dist = resolve_distribution_spec(config.spec)
|
||||||
if dist is None:
|
if dist is None:
|
||||||
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
|
raise ValueError(f"Could not find any registered spec `{config.spec}`")
|
||||||
|
|
||||||
configure_llama_distribution(dist, config)
|
configure_llama_distribution(dist, config)
|
||||||
|
|
||||||
|
|
||||||
def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
|
def configure_llama_distribution(dist: "Distribution", config: "DistributionConfig"):
|
||||||
from llama_toolchain.common.exec import run_command
|
from llama_toolchain.common.exec import run_command
|
||||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||||
from llama_toolchain.common.serialize import EnumEncoder
|
from llama_toolchain.common.serialize import EnumEncoder
|
||||||
|
@ -68,47 +67,39 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]):
|
||||||
|
|
||||||
python_exe = run_command(shlex.split("which python"))
|
python_exe = run_command(shlex.split("which python"))
|
||||||
# simple check
|
# simple check
|
||||||
conda_env = config["conda_env"]
|
conda_env = config.conda_env
|
||||||
if conda_env not in python_exe:
|
if conda_env not in python_exe:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
f"Please re-run configure by activating the `{conda_env}` conda environment"
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_config = config
|
if config.providers:
|
||||||
if "providers" in existing_config:
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Configuration already exists for {config['name']}. Will overwrite...",
|
f"Configuration already exists for {config.name}. Will overwrite...",
|
||||||
"yellow",
|
"yellow",
|
||||||
attrs=["bold"],
|
attrs=["bold"],
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_configs = {}
|
|
||||||
for api, provider_spec in dist.provider_specs.items():
|
for api, provider_spec in dist.provider_specs.items():
|
||||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = prompt_for_config(
|
provider_config = prompt_for_config(
|
||||||
config_type,
|
config_type,
|
||||||
(
|
(
|
||||||
config_type(**existing_config["providers"][api.value])
|
config_type(**config.providers[api.value])
|
||||||
if existing_config
|
if api.value in config.providers
|
||||||
and "providers" in existing_config
|
|
||||||
and api.value in existing_config["providers"]
|
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
provider_configs[api.value] = {
|
|
||||||
|
config.providers[api.value] = {
|
||||||
"provider_id": provider_spec.provider_id,
|
"provider_id": provider_spec.provider_id,
|
||||||
**config.dict(),
|
**provider_config.dict(),
|
||||||
}
|
}
|
||||||
|
|
||||||
dist_config = {
|
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
|
||||||
**existing_config,
|
|
||||||
"providers": provider_configs,
|
|
||||||
}
|
|
||||||
|
|
||||||
config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml"
|
|
||||||
with open(config_path, "w") as fp:
|
with open(config_path, "w") as fp:
|
||||||
dist_config = json.loads(json.dumps(dist_config, cls=EnumEncoder))
|
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
|
||||||
fp.write(yaml.dump(dist_config, sort_keys=False))
|
fp.write(yaml.dump(dist_config, sort_keys=False))
|
||||||
|
|
||||||
print(f"YAML configuration has been written to {config_path}")
|
print(f"YAML configuration has been written to {config_path}")
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import shlex
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
@ -56,6 +55,7 @@ class DistributionInstall(Subcommand):
|
||||||
|
|
||||||
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_toolchain.common.exec import run_with_pty
|
from llama_toolchain.common.exec import run_with_pty
|
||||||
|
from llama_toolchain.distribution.datatypes import DistributionConfig
|
||||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
from llama_toolchain.distribution.registry import resolve_distribution_spec
|
||||||
|
|
||||||
|
@ -86,12 +86,12 @@ class DistributionInstall(Subcommand):
|
||||||
|
|
||||||
config_file = distrib_dir / "config.yaml"
|
config_file = distrib_dir / "config.yaml"
|
||||||
with open(config_file, "w") as f:
|
with open(config_file, "w") as f:
|
||||||
c = {
|
c = DistributionConfig(
|
||||||
"conda_env": conda_env,
|
spec=dist.spec_id,
|
||||||
"spec": dist.spec_id,
|
name=args.name,
|
||||||
"name": args.name,
|
conda_env=conda_env,
|
||||||
}
|
)
|
||||||
f.write(yaml.dump(c))
|
f.write(yaml.dump(c.dict(), sort_keys=False))
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
|
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
|
||||||
|
@ -105,4 +105,5 @@ class DistributionInstall(Subcommand):
|
||||||
conda deactivate && conda activate {conda_env}
|
conda deactivate && conda activate {conda_env}
|
||||||
llama distribution configure --name {args.name}
|
llama distribution configure --name {args.name}
|
||||||
"""
|
"""
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -16,7 +16,6 @@ import httpx
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.common.config_dirs import DEFAULT_CHECKPOINT_DIR
|
|
||||||
|
|
||||||
|
|
||||||
class Download(Subcommand):
|
class Download(Subcommand):
|
||||||
|
@ -109,9 +108,10 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
|
|
||||||
def _meta_download(self, model: "Model", meta_url: str):
|
def _meta_download(self, model: "Model", meta_url: str):
|
||||||
from llama_models.sku_list import llama_meta_net_info
|
from llama_models.sku_list import llama_meta_net_info
|
||||||
|
|
||||||
from llama_toolchain.common.model_utils import model_local_dir
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
|
|
||||||
output_dir = model_local_dir(model)
|
output_dir = Path(model_local_dir(model))
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
info = llama_meta_net_info(model)
|
info = llama_meta_net_info(model)
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from strong_typing.schema import json_schema_type
|
from strong_typing.schema import json_schema_type
|
||||||
|
@ -91,9 +91,13 @@ class DistributionSpec(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class InstalledDistribution(BaseModel):
|
class DistributionConfig(BaseModel):
|
||||||
"""References to a installed / configured DistributionSpec"""
|
"""References to a installed / configured DistributionSpec"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
spec_id: str
|
spec: str
|
||||||
# This is the class which represents the configs written by `configure`
|
conda_env: str
|
||||||
|
providers: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Provider configurations for each of the APIs provided by this distribution",
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue