forked from phoenix-oss/llama-stack-mirror
chore: make cprint write to stderr (#2250)
Also do sys.exit(1) in case of errors
This commit is contained in:
parent
c25bd0ad58
commit
5a422e236c
11 changed files with 81 additions and 44 deletions
|
@ -9,6 +9,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -377,14 +378,15 @@ def _meta_download(
|
||||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||||
asyncio.run(downloader.download_all(tasks))
|
asyncio.run(downloader.download_all(tasks))
|
||||||
|
|
||||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
|
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
|
||||||
cprint(
|
cprint(
|
||||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||||
"white",
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
cprint(
|
cprint(
|
||||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||||
"yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
build_config = available_templates[args.template]
|
build_config = available_templates[args.template]
|
||||||
|
@ -88,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif args.providers:
|
elif args.providers:
|
||||||
|
@ -97,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
api, provider = api_provider.split("=")
|
api, provider = api_provider.split("=")
|
||||||
|
@ -105,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{api} is not a valid API.",
|
f"{api} is not a valid API.",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if provider in providers_for_api:
|
if provider in providers_for_api:
|
||||||
|
@ -113,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"{provider} is not a valid provider for the {api} API.",
|
f"{provider} is not a valid provider for the {api} API.",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
distribution_spec = DistributionSpec(
|
distribution_spec = DistributionSpec(
|
||||||
|
@ -123,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -151,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
|
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
image_name = f"llamastack-{name}"
|
image_name = f"llamastack-{name}"
|
||||||
else:
|
else:
|
||||||
cprint(
|
cprint(
|
||||||
f"Using conda environment {image_name}",
|
f"Using conda environment {image_name}",
|
||||||
color="green",
|
color="green",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_name = f"llamastack-{name}"
|
image_name = f"llamastack-{name}"
|
||||||
|
@ -169,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
color="green",
|
color="green",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Tip: use <TAB> to see options for the providers.\n")
|
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
||||||
|
|
||||||
providers = dict()
|
providers = dict()
|
||||||
for api, providers_for_api in get_provider_registry().items():
|
for api, providers_for_api in get_provider_registry().items():
|
||||||
|
@ -213,6 +222,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Could not parse config file {args.config}: {e}",
|
f"Could not parse config file {args.config}: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -239,14 +249,17 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
cprint(
|
cprint(
|
||||||
f"Error building stack: {exc}",
|
f"Error building stack: {exc}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
cprint("Stack trace:", color="red")
|
cprint("Stack trace:", color="red", file=sys.stderr)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if run_config is None:
|
if run_config is None:
|
||||||
cprint(
|
cprint(
|
||||||
"Run config path is empty",
|
"Run config path is empty",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -304,6 +317,7 @@ def _generate_run_config(
|
||||||
cprint(
|
cprint(
|
||||||
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
# Set config_type to None to avoid UnboundLocalError
|
# Set config_type to None to avoid UnboundLocalError
|
||||||
config_type = None
|
config_type = None
|
||||||
|
@ -331,10 +345,7 @@ def _generate_run_config(
|
||||||
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
||||||
# makes sense to display this message
|
# makes sense to display this message
|
||||||
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
||||||
cprint(
|
cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
|
||||||
f"You can now run your stack with `llama stack run {run_config_file}`",
|
|
||||||
color="green",
|
|
||||||
)
|
|
||||||
return run_config_file
|
return run_config_file
|
||||||
|
|
||||||
|
|
||||||
|
@ -372,7 +383,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||||
# Only do this if we're building a container image and we're not using a template
|
# Only do this if we're building a container image and we're not using a template
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||||
cprint("Generating run.yaml file", color="green")
|
cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
|
||||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
|
@ -396,11 +407,13 @@ def _run_stack_build_command_from_build_config(
|
||||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||||
shutil.copy(path, run_config_file)
|
shutil.copy(path, run_config_file)
|
||||||
|
|
||||||
cprint("Build Successful!", color="green")
|
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||||
cprint("You can find the newly-built template here: " + colored(template_path, "light_blue"))
|
cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
|
||||||
cprint(
|
cprint(
|
||||||
"You can run the new Llama Stack distro via: "
|
"You can run the new Llama Stack distro via: "
|
||||||
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue")
|
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
|
||||||
|
color="green",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return template_path
|
return template_path
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -58,8 +58,8 @@ class StackRemove(Subcommand):
|
||||||
"""Display available stacks in a table"""
|
"""Display available stacks in a table"""
|
||||||
distributions = self._get_distribution_dirs()
|
distributions = self._get_distribution_dirs()
|
||||||
if not distributions:
|
if not distributions:
|
||||||
print("No stacks found in ~/.llama/distributions")
|
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
|
||||||
return
|
sys.exit(1)
|
||||||
|
|
||||||
headers = ["Stack Name", "Path"]
|
headers = ["Stack Name", "Path"]
|
||||||
rows = [[name, str(path)] for name, path in distributions.items()]
|
rows = [[name, str(path)] for name, path in distributions.items()]
|
||||||
|
@ -71,19 +71,20 @@ class StackRemove(Subcommand):
|
||||||
if args.all:
|
if args.all:
|
||||||
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
|
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
|
||||||
if confirm != "yes-i-really-want":
|
if confirm != "yes-i-really-want":
|
||||||
print("Deletion cancelled.")
|
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
|
||||||
for name, path in distributions.items():
|
for name, path in distributions.items():
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(path)
|
shutil.rmtree(path)
|
||||||
print(f"Deleted stack: {name}")
|
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cprint(
|
cprint(
|
||||||
f"Failed to delete stack {name}: {e}",
|
f"Failed to delete stack {name}: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sys.exit(2)
|
sys.exit(1)
|
||||||
|
|
||||||
if not args.name:
|
if not args.name:
|
||||||
self._list_stacks()
|
self._list_stacks()
|
||||||
|
@ -95,22 +96,20 @@ class StackRemove(Subcommand):
|
||||||
cprint(
|
cprint(
|
||||||
f"Stack not found: {args.name}",
|
f"Stack not found: {args.name}",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return
|
sys.exit(1)
|
||||||
|
|
||||||
stack_path = distributions[args.name]
|
stack_path = distributions[args.name]
|
||||||
|
|
||||||
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
|
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
|
||||||
if confirm != "y":
|
if confirm != "y":
|
||||||
print("Deletion cancelled.")
|
cprint("Deletion cancelled.", color="green", file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(stack_path)
|
shutil.rmtree(stack_path)
|
||||||
print(f"Successfully deleted stack: {args.name}")
|
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cprint(
|
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
|
||||||
f"Failed to delete stack {args.name}: {e}",
|
sys.exit(1)
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
sys.exit(2)
|
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -95,10 +96,11 @@ def print_pip_install_help(config: BuildConfig):
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
|
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
|
||||||
"yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
for special_dep in special_deps:
|
for special_dep in special_deps:
|
||||||
cprint(f"uv pip install {special_dep}", "yellow")
|
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Union, get_args, get_origin
|
from typing import Any, Union, get_args, get_origin
|
||||||
|
@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type:
|
||||||
try:
|
try:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
cprint(data, "red")
|
cprint(data, color="red", file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield parse_obj_as(return_type, data)
|
yield parse_obj_as(return_type, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error with parsing or validation: {e}")
|
cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
|
||||||
print(data)
|
cprint(data, color="red", file=sys.stderr)
|
||||||
|
|
||||||
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
||||||
webmethod, sig = self.routes[method_name]
|
webmethod, sig = self.routes[method_name]
|
||||||
|
|
|
@ -9,6 +9,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -210,10 +211,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.endpoint_impls = None
|
self.endpoint_impls = None
|
||||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||||
except ModuleNotFoundError as _e:
|
except ModuleNotFoundError as _e:
|
||||||
cprint(_e.msg, "red")
|
cprint(_e.msg, color="red", file=sys.stderr)
|
||||||
cprint(
|
cprint(
|
||||||
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
|
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
|
||||||
"yellow",
|
color="yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
if self.config_path_or_template_name.endswith(".yaml"):
|
if self.config_path_or_template_name.endswith(".yaml"):
|
||||||
# Convert Provider objects to their types
|
# Convert Provider objects to their types
|
||||||
|
@ -234,7 +236,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
cprint(
|
cprint(
|
||||||
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
|
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
|
||||||
"yellow",
|
"yellow",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
cprint(
|
||||||
|
"Please check your internet connection and try again.",
|
||||||
|
"red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise _e
|
raise _e
|
||||||
|
|
||||||
if Api.telemetry in self.impls:
|
if Api.telemetry in self.impls:
|
||||||
|
|
|
@ -8,6 +8,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
@ -33,6 +34,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||||
cprint(
|
cprint(
|
||||||
"No current conda environment detected, please specify a conda environment name with --image-name",
|
"No current conda environment detected, please specify a conda environment name with --image-name",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -49,12 +51,13 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||||
return envpath
|
return envpath
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"Using conda environment: {env_name}")
|
cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr)
|
||||||
conda_prefix = get_conda_prefix(env_name)
|
conda_prefix = get_conda_prefix(env_name)
|
||||||
if not conda_prefix:
|
if not conda_prefix:
|
||||||
cprint(
|
cprint(
|
||||||
f"Conda environment {env_name} does not exist.",
|
f"Conda environment {env_name} does not exist.",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -63,6 +66,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||||
cprint(
|
cprint(
|
||||||
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
|
f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
@ -73,9 +77,10 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||||
cprint(
|
cprint(
|
||||||
"No current virtual environment detected, please specify a virtual environment name with --image-name",
|
"No current virtual environment detected, please specify a virtual environment name with --image-name",
|
||||||
color="red",
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
print(f"Using virtual environment: {env_name}")
|
cprint(f"Using virtual environment: {env_name}", file=sys.stderr)
|
||||||
|
|
||||||
script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
|
script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
|
||||||
run_args = [
|
run_args = [
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
@ -234,7 +235,7 @@ def get_logger(
|
||||||
|
|
||||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||||
if env_config:
|
if env_config:
|
||||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
|
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
|
||||||
_category_levels.update(parse_environment_config(env_config))
|
_category_levels.update(parse_environment_config(env_config))
|
||||||
|
|
||||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||||
|
|
|
@ -174,6 +174,7 @@ class Llama3:
|
||||||
cprint(
|
cprint(
|
||||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||||
"red",
|
"red",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||||
|
|
||||||
|
@ -184,7 +185,11 @@ class Llama3:
|
||||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||||
|
|
||||||
if max_prompt_len >= params.max_seq_len:
|
if max_prompt_len >= params.max_seq_len:
|
||||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
cprint(
|
||||||
|
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||||
|
|
|
@ -133,9 +133,9 @@ class Llama4:
|
||||||
|
|
||||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||||
if print_model_input:
|
if print_model_input:
|
||||||
cprint("Input to model:\n", "yellow")
|
cprint("Input to model:\n", color="yellow", file=sys.stderr)
|
||||||
for inp in llm_inputs:
|
for inp in llm_inputs:
|
||||||
cprint(self.tokenizer.decode(inp.tokens), "grey")
|
cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr)
|
||||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||||
|
|
||||||
bsz = len(llm_inputs)
|
bsz = len(llm_inputs)
|
||||||
|
@ -145,7 +145,7 @@ class Llama4:
|
||||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||||
|
|
||||||
if max_prompt_len >= params.max_seq_len:
|
if max_prompt_len >= params.max_seq_len:
|
||||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
|
||||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl(
|
||||||
first = token_results[0]
|
first = token_results[0]
|
||||||
if not first.finished and not first.ignore_token:
|
if not first.finished and not first.ignore_token:
|
||||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||||
cprint(first.text, "cyan", end="")
|
cprint(first.text, color="cyan", end="", file=sys.stderr)
|
||||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||||
cprint(f"<{first.token}>", "magenta", end="")
|
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
|
||||||
|
|
||||||
for result in token_results:
|
for result in token_results:
|
||||||
idx = result.batch_idx
|
idx = result.batch_idx
|
||||||
|
@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl(
|
||||||
for token_results in self.generator.chat_completion([request]):
|
for token_results in self.generator.chat_completion([request]):
|
||||||
token_result = token_results[0]
|
token_result = token_results[0]
|
||||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||||
cprint(token_result.text, "cyan", end="")
|
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
|
||||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||||
cprint(f"<{token_result.token}>", "magenta", end="")
|
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
|
||||||
|
|
||||||
if token_result.token == tokenizer.eot_id:
|
if token_result.token == tokenizer.eot_id:
|
||||||
stop_reason = StopReason.end_of_turn
|
stop_reason = StopReason.end_of_turn
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue