From ade574a0ef02c1ade448e3587ca149488d6a0011 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 6 Aug 2024 18:56:22 -0700 Subject: [PATCH] minor fixes --- llama_toolchain/cli/distribution/configure.py | 2 +- llama_toolchain/cli/model/describe.py | 4 ++-- llama_toolchain/cli/table.py | 2 +- llama_toolchain/inference/client.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index fa1a42dc0..235d48da1 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -102,8 +102,8 @@ def configure_llama_distribution(dist: "Distribution", config: Dict[str, Any]): } dist_config = { - "providers": provider_configs, **existing_config, + "providers": provider_configs, } config_path = DISTRIBS_BASE_DIR / existing_config["name"] / "config.yaml" diff --git a/llama_toolchain/cli/model/describe.py b/llama_toolchain/cli/model/describe.py index a24fe15f7..e0fb44a96 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_toolchain/cli/model/describe.py @@ -49,9 +49,9 @@ class ModelDescribe(Subcommand): rows = [ ( colored("Model", "white", attrs=["bold"]), - colored(model.sku.value, "white", attrs=["bold"]), + colored(model.descriptor(), "white", attrs=["bold"]), ), - ("HuggingFace ID", model.huggingface_id or ""), + ("HuggingFace ID", model.huggingface_repo or ""), ("Description", model.description_markdown), ("Context Length", f"{model.max_seq_length // 1024}K tokens"), ("Weights format", model.quantization_format.value), diff --git a/llama_toolchain/cli/table.py b/llama_toolchain/cli/table.py index b63ae3467..3ee7eea13 100644 --- a/llama_toolchain/cli/table.py +++ b/llama_toolchain/cli/table.py @@ -45,7 +45,7 @@ def format_row(row, col_widths): def print_table(rows, headers=None, separate_rows: bool = False): def itemlen(item): - return len(strip_ansi_colors(item)) + return max([len(line) for line in strip_ansi_colors(item).split("\n")]) rows = [[x or "" for x in row] for row in rows] if not headers: diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index 36ee6225a..aa84f906d 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -50,7 +50,6 @@ class InferenceClient(Inference): headers={"Content-Type": "application/json"}, timeout=20, ) as response: - print("Headers", response.headers) if response.status_code != 200: content = await response.aread() cprint(