From 2076d2b6dbd1b2b2fa7edcd467bd3e1df6ea079e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 27 Aug 2024 21:40:43 -0700 Subject: [PATCH] api build works for conda now --- llama_toolchain/cli/api/build.py | 19 +++++++++++++------ llama_toolchain/cli/api/configure.py | 5 ++--- llama_toolchain/common/prompt_for_config.py | 2 -- llama_toolchain/common/serialize.py | 3 +++ .../distribution/build_conda_env.sh | 14 ++++++++++---- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py index 0b589cdcb..4e8f7cfba 100644 --- a/llama_toolchain/cli/api/build.py +++ b/llama_toolchain/cli/api/build.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import argparse +import json import os import random import string @@ -137,6 +138,7 @@ class ApiBuild(Subcommand): def _run_api_build_command(self, args: argparse.Namespace) -> None: from llama_toolchain.common.exec import run_with_pty from llama_toolchain.distribution.distribution import api_providers + from llama_toolchain.common.serialize import EnumEncoder os.makedirs(BUILDS_BASE_DIR, exist_ok=True) all_providers = api_providers() @@ -174,7 +176,7 @@ class ApiBuild(Subcommand): } with open(package_file, "w") as f: c = PackageConfig( - built_at=datetime.now(), + built_at=str(datetime.now()), package_name=package_name, docker_image=( package_name if args.type == BuildType.container.value else None @@ -184,7 +186,8 @@ class ApiBuild(Subcommand): ), providers=stub_config, ) - f.write(yaml.dump(c.dict(), sort_keys=False)) + to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder)) + f.write(yaml.dump(to_write, sort_keys=False)) if args.type == BuildType.container.value: script = pkg_resources.resource_filename( @@ -209,10 +212,14 @@ class ApiBuild(Subcommand): ] return_code = run_with_pty(args) - assert return_code == 0, cprint( - f"Failed to build target {package_name}", color="red" - ) + if return_code != 0: + cprint( + f"Failed to build target {package_name} with return code {return_code}", + color="red", + ) + return + cprint( - f"Target `{target_name}` built with configuration at {str(package_file)}", + f"Target `{package_name}` built with configuration at {str(package_file)}", color="green", ) diff --git a/llama_toolchain/cli/api/configure.py b/llama_toolchain/cli/api/configure.py index c2cb8b16f..a3582f02e 100644 --- a/llama_toolchain/cli/api/configure.py +++ b/llama_toolchain/cli/api/configure.py @@ -81,13 +81,12 @@ def configure_llama_provider(config_file: Path) -> None: provider_spec = providers[provider_id] cprint(f"Configuring API surface: {api}", "white", attrs=["bold"]) config_type = instantiate_class_type(provider_spec.config_class) - print(f"Config type: {config_type}") provider_config = prompt_for_config( config_type, ) print("") - provider_configs[api.value] = { + provider_configs[api] = { "provider_id": provider_id, **provider_config.dict(), } @@ -97,4 +96,4 @@ def configure_llama_provider(config_file: Path) -> None: to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder)) fp.write(yaml.dump(to_write, sort_keys=False)) - print(f"YAML configuration has been written to {config_path}") + print(f"YAML configuration has been written to {config_file}") diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index c87716750..6c53477d8 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -71,7 +71,6 @@ def prompt_for_config( """ config_data = {} - print(f"Configuring {config_type.__name__}:") for field_name, field in config_type.__fields__.items(): field_type = field.annotation @@ -86,7 +85,6 @@ def prompt_for_config( if not isinstance(field.default, PydanticUndefinedType) else None ) - print(f" {field_name}: {field_type} (default: {default_value})") is_required = field.is_required # Skip fields with Literal type diff --git a/llama_toolchain/common/serialize.py b/llama_toolchain/common/serialize.py index 813851fe9..667902beb 100644 --- a/llama_toolchain/common/serialize.py +++ b/llama_toolchain/common/serialize.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import json +from datetime import datetime from enum import Enum @@ -12,4 +13,6 @@ class EnumEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Enum): return obj.value + elif isinstance(obj, datetime): + return obj.isoformat() return super().default(obj) diff --git a/llama_toolchain/distribution/build_conda_env.sh b/llama_toolchain/distribution/build_conda_env.sh index 0b45edf09..ecdeaba1b 100755 --- a/llama_toolchain/distribution/build_conda_env.sh +++ b/llama_toolchain/distribution/build_conda_env.sh @@ -10,7 +10,13 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} -echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR" +if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then + echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR" +fi +if [ -n "$LLAMA_MODELS_DIR" ]; then + echo "Using llama-models-dir=$LLAMA_MODELS_DIR" +fi + set -euo pipefail if [ "$#" -ne 3 ]; then @@ -82,9 +88,9 @@ ensure_conda_env_python310() { fi echo "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR" - pip install -e "$LLAMA_TOOLCHAIN_DIR" + pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR" else - pip install llama-toolchain + pip install --no-cache-dir llama-toolchain fi if [ -n "$LLAMA_MODELS_DIR" ]; then @@ -95,7 +101,7 @@ ensure_conda_env_python310() { echo "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR" pip uninstall -y llama-models - pip install -e "$LLAMA_MODELS_DIR" + pip install --no-cache-dir -e "$LLAMA_MODELS_DIR" fi # Install pip dependencies