chore: make cprint write to stderr (#2250)

Also call sys.exit(1) on errors.
raghotham 2025-05-24 23:39:57 -07:00 committed by GitHub
parent c25bd0ad58
commit 5a422e236c
11 changed files with 81 additions and 44 deletions
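
The pattern applied across all eleven files is the same: route user-facing diagnostics through termcolor's cprint with an explicit file=sys.stderr (keeping stdout clean for pipeable output), and exit with a non-zero status on failure. A minimal sketch of that convention, using a hypothetical remove_path helper rather than code from this commit:

import shutil
import sys

from termcolor import cprint


def remove_path(path: str) -> None:
    """Hypothetical example illustrating the convention, not code from this repo."""
    try:
        shutil.rmtree(path)
    except Exception as e:
        # Diagnostics are colored and written to stderr, so stdout stays
        # clean for machine-readable output; sys.exit(1) gives the shell
        # a non-zero status to detect the failure.
        cprint(f"Failed to delete {path}: {e}", color="red", file=sys.stderr)
        sys.exit(1)
    cprint(f"Deleted {path}", color="green", file=sys.stderr)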


@@ -9,6 +9,7 @@ import asyncio
 import json
 import os
 import shutil
+import sys
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial
@@ -377,14 +378,15 @@ def _meta_download(
     downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
     asyncio.run(downloader.download_all(tasks))
-    cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
+    cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
     cprint(
         f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
         "white",
+        file=sys.stderr,
     )
     cprint(
         f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
-        "yellow",
+        color="yellow",
+        file=sys.stderr,
     )


@@ -79,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
                 color="red",
+                file=sys.stderr,
             )
             sys.exit(1)
         build_config = available_templates[args.template]
@@ -88,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
                 color="red",
+                file=sys.stderr,
             )
             sys.exit(1)
     elif args.providers:
@@ -97,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 cprint(
                     "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
                     color="red",
+                    file=sys.stderr,
                 )
                 sys.exit(1)
             api, provider = api_provider.split("=")
@@ -105,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 cprint(
                     f"{api} is not a valid API.",
                     color="red",
+                    file=sys.stderr,
                 )
                 sys.exit(1)
             if provider in providers_for_api:
@@ -113,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 cprint(
                     f"{provider} is not a valid provider for the {api} API.",
                     color="red",
+                    file=sys.stderr,
                 )
                 sys.exit(1)
         distribution_spec = DistributionSpec(
@@ -123,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Please specify a image-type (container | conda | venv) for {args.template}",
                 color="red",
+                file=sys.stderr,
             )
             sys.exit(1)
@@ -151,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 cprint(
                     f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
                     color="yellow",
+                    file=sys.stderr,
                 )
                 image_name = f"llamastack-{name}"
             else:
                 cprint(
                     f"Using conda environment {image_name}",
                     color="green",
+                    file=sys.stderr,
                 )
         else:
             image_name = f"llamastack-{name}"
@@ -169,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 """,
             ),
             color="green",
+            file=sys.stderr,
         )
-        print("Tip: use <TAB> to see options for the providers.\n")
+        cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)

         providers = dict()
         for api, providers_for_api in get_provider_registry().items():
@@ -213,6 +222,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             cprint(
                 f"Could not parse config file {args.config}: {e}",
                 color="red",
+                file=sys.stderr,
             )
             sys.exit(1)
@@ -239,14 +249,17 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint(
             f"Error building stack: {exc}",
             color="red",
+            file=sys.stderr,
         )
-        cprint("Stack trace:", color="red")
+        cprint("Stack trace:", color="red", file=sys.stderr)
         traceback.print_exc()
         sys.exit(1)
     if run_config is None:
         cprint(
             "Run config path is empty",
             color="red",
+            file=sys.stderr,
         )
         sys.exit(1)
@@ -304,6 +317,7 @@ def _generate_run_config(
             cprint(
                 f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
                 color="yellow",
+                file=sys.stderr,
             )
             # Set config_type to None to avoid UnboundLocalError
             config_type = None
@@ -331,10 +345,7 @@ def _generate_run_config(
     # For non-container builds, the run.yaml is generated at the very end of the build process so it
     # makes sense to display this message
     if build_config.image_type != LlamaStackImageType.CONTAINER.value:
-        cprint(
-            f"You can now run your stack with `llama stack run {run_config_file}`",
-            color="green",
-        )
+        cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
     return run_config_file
@@ -372,7 +383,7 @@ def _run_stack_build_command_from_build_config(
     # Generate the run.yaml so it can be included in the container image with the proper entrypoint
     # Only do this if we're building a container image and we're not using a template
     if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
-        cprint("Generating run.yaml file", color="green")
+        cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
         run_config_file = _generate_run_config(build_config, build_dir, image_name)

     with open(build_file_path, "w") as f:
@@ -396,11 +407,13 @@ def _run_stack_build_command_from_build_config(
         run_config_file = build_dir / f"{template_name}-run.yaml"
         shutil.copy(path, run_config_file)

-        cprint("Build Successful!", color="green")
-        cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
+        cprint("Build Successful!", color="green", file=sys.stderr)
+        cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
         cprint(
             "You can run the new Llama Stack distro via: "
-            + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
+            + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
+            color="green",
+            file=sys.stderr,
         )
         return template_path
     else:


@@ -58,8 +58,8 @@ class StackRemove(Subcommand):
         """Display available stacks in a table"""
         distributions = self._get_distribution_dirs()
         if not distributions:
-            print("No stacks found in ~/.llama/distributions")
-            return
+            cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
+            sys.exit(1)

         headers = ["Stack Name", "Path"]
         rows = [[name, str(path)] for name, path in distributions.items()]
@@ -71,19 +71,20 @@ class StackRemove(Subcommand):
         if args.all:
             confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
             if confirm != "yes-i-really-want":
-                print("Deletion cancelled.")
+                cprint("Deletion cancelled.", color="green", file=sys.stderr)
                 return

             for name, path in distributions.items():
                 try:
                     shutil.rmtree(path)
-                    print(f"Deleted stack: {name}")
+                    cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
                 except Exception as e:
                     cprint(
                         f"Failed to delete stack {name}: {e}",
                         color="red",
+                        file=sys.stderr,
                     )
-                    sys.exit(2)
+                    sys.exit(1)

         if not args.name:
             self._list_stacks()
@@ -95,22 +96,20 @@ class StackRemove(Subcommand):
             cprint(
                 f"Stack not found: {args.name}",
                 color="red",
+                file=sys.stderr,
             )
-            return
+            sys.exit(1)

         stack_path = distributions[args.name]
         confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
         if confirm != "y":
-            print("Deletion cancelled.")
+            cprint("Deletion cancelled.", color="green", file=sys.stderr)
             return

         try:
             shutil.rmtree(stack_path)
-            print(f"Successfully deleted stack: {args.name}")
+            cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
         except Exception as e:
-            cprint(
-                f"Failed to delete stack {args.name}: {e}",
-                color="red",
-            )
-            sys.exit(2)
+            cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
+            sys.exit(1)
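
A minimal sketch of how one might verify this behavior from outside the process; the inline `python -c` stand-in is a hypothetical placeholder for any CLI command that follows the convention above:

import subprocess
import sys

# Run a failing command and confirm diagnostics land on stderr with a
# non-zero exit status, while stdout stays empty and pipeable.
result = subprocess.run(
    [sys.executable, "-c", "import sys; print('oops', file=sys.stderr); sys.exit(1)"],
    capture_output=True,
    text=True,
)
assert result.returncode == 1
assert result.stdout == ""
assert "oops" in result.stderr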


@@ -6,6 +6,7 @@
 import importlib.resources
 import logging
+import sys
 from pathlib import Path

 from pydantic import BaseModel
@@ -95,10 +96,11 @@ def print_pip_install_help(config: BuildConfig):
     cprint(
         f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
-        "yellow",
+        color="yellow",
+        file=sys.stderr,
     )
     for special_dep in special_deps:
-        cprint(f"uv pip install {special_dep}", "yellow")
+        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
     print()


@@ -6,6 +6,7 @@
 import inspect
 import json
+import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Any, Union, get_args, get_origin
@@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type:
                 try:
                     data = json.loads(data)
                     if "error" in data:
-                        cprint(data, "red")
+                        cprint(data, color="red", file=sys.stderr)
                         continue
                     yield parse_obj_as(return_type, data)
                 except Exception as e:
-                    print(f"Error with parsing or validation: {e}")
-                    print(data)
+                    cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
+                    cprint(data, color="red", file=sys.stderr)

     def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
         webmethod, sig = self.routes[method_name]


@@ -9,6 +9,7 @@ import inspect
 import json
 import logging
 import os
+import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
@@ -210,10 +211,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             self.endpoint_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
         except ModuleNotFoundError as _e:
-            cprint(_e.msg, "red")
+            cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
-                "yellow",
+                color="yellow",
+                file=sys.stderr,
             )
             if self.config_path_or_template_name.endswith(".yaml"):
                 # Convert Provider objects to their types
@@ -234,6 +236,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 cprint(
                     f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                     "yellow",
+                    file=sys.stderr,
+                )
+                cprint(
+                    "Please check your internet connection and try again.",
+                    "red",
+                    file=sys.stderr,
                 )
             raise _e


@@ -8,6 +8,7 @@ import logging
 import os
 import signal
 import subprocess
+import sys

 from termcolor import cprint
@@ -33,6 +34,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
         cprint(
             "No current conda environment detected, please specify a conda environment name with --image-name",
             color="red",
+            file=sys.stderr,
         )
         return
@@ -49,12 +51,13 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
                     return envpath
             return None

-        print(f"Using conda environment: {env_name}")
+        cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr)
         conda_prefix = get_conda_prefix(env_name)

         if not conda_prefix:
             cprint(
                 f"Conda environment {env_name} does not exist.",
                 color="red",
+                file=sys.stderr,
             )
             return
@@ -63,6 +66,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
     else:
@@ -73,9 +77,10 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 "No current virtual environment detected, please specify a virtual environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
-        print(f"Using virtual environment: {env_name}")
+        cprint(f"Using virtual environment: {env_name}", file=sys.stderr)

     script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
     run_args = [


@@ -6,6 +6,7 @@
 import logging
 import os
+import sys
 from logging.config import dictConfig

 from rich.console import Console
@@ -234,7 +235,7 @@ def get_logger(
     env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
     if env_config:
-        cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
+        cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
         _category_levels.update(parse_environment_config(env_config))

     log_file = os.environ.get("LLAMA_STACK_LOG_FILE")


@@ -174,6 +174,7 @@ class Llama3:
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
+                    file=sys.stderr,
                 )

         prompt_tokens = [inp.tokens for inp in llm_inputs]
@@ -184,7 +185,11 @@ class Llama3:
         max_prompt_len = max(len(t) for t in prompt_tokens)

         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(
+                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
+                color="red",
+                file=sys.stderr,
+            )
             return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)


@@ -133,9 +133,9 @@ class Llama4:
         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            cprint("Input to model:\n", "yellow")
+            cprint("Input to model:\n", color="yellow", file=sys.stderr)
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr)

         prompt_tokens = [inp.tokens for inp in llm_inputs]
         bsz = len(llm_inputs)
@@ -145,7 +145,7 @@ class Llama4:
         max_prompt_len = max(len(t) for t in prompt_tokens)

         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr)
             return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)


@@ -6,6 +6,7 @@
 import asyncio
 import os
+import sys
 from collections.abc import AsyncGenerator

 from pydantic import BaseModel
@@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl(
                 first = token_results[0]
                 if not first.finished and not first.ignore_token:
                     if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
-                        cprint(first.text, "cyan", end="")
+                        cprint(first.text, color="cyan", end="", file=sys.stderr)
                     if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                        cprint(f"<{first.token}>", "magenta", end="")
+                        cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)

                 for result in token_results:
                     idx = result.batch_idx
@@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl(
             for token_results in self.generator.chat_completion([request]):
                 token_result = token_results[0]
                 if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                    cprint(token_result.text, "cyan", end="")
+                    cprint(token_result.text, color="cyan", end="", file=sys.stderr)
                 if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                    cprint(f"<{token_result.token}>", "magenta", end="")
+                    cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)

                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn