From 37cf60b73292468775dbfc876e7838fb1b7ccf96 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 18 Feb 2025 19:41:37 -0800 Subject: [PATCH] style: remove prints in codebase (#1146) # What does this PR do? - replace prints in codebase with logger - update print_table to use rich Table ## Test Plan - library client script in https://github.com/meta-llama/llama-stack/pull/1145 ``` llama stack list-providers ``` image [//]: # (## Documentation) --- llama_stack/cli/table.py | 75 +++++-------------- llama_stack/distribution/library_client.py | 11 +-- .../remote/inference/nvidia/nvidia.py | 12 ++- .../remote/inference/nvidia/utils.py | 5 +- 4 files changed, 38 insertions(+), 65 deletions(-) diff --git a/llama_stack/cli/table.py b/llama_stack/cli/table.py index 599749231..bf59e6103 100644 --- a/llama_stack/cli/table.py +++ b/llama_stack/cli/table.py @@ -4,75 +4,36 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import re -import textwrap from typing import Iterable -from termcolor import cprint - - -def strip_ansi_colors(text): - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) - - -def format_row(row, col_widths): - def wrap(text, width): - lines = [] - for line in text.split("\n"): - if line.strip() == "": - lines.append("") - else: - lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False)) - return lines - - wrapped = [wrap(item, width) for item, width in zip(row, col_widths, strict=False)] - max_lines = max(len(subrow) for subrow in wrapped) - - lines = [] - for i in range(max_lines): - line = [] - for cell_lines, width in zip(wrapped, col_widths, strict=False): - value = cell_lines[i] if i < len(cell_lines) else "" - line.append(value + " " * (width - len(strip_ansi_colors(value)))) - lines.append("| " + (" | ".join(line)) + " |") - - return "\n".join(lines) +from rich.console import Console +from rich.table import Table def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()): - def itemlen(item): - return max([len(line) for line in strip_ansi_colors(item).split("\n")]) - + # Convert rows and handle None values rows = [[x or "" for x in row] for row in rows] + # Sort rows if sort_by is specified if sort_by: rows.sort(key=lambda x: tuple(x[i] for i in sort_by)) - if not headers: - col_widths = [max(itemlen(item) for item in col) for col in zip(*rows, strict=False)] - else: - col_widths = [ - max( - itemlen(header), - max(itemlen(item) for item in col), - ) - for header, col in zip(headers, zip(*rows, strict=False), strict=False) - ] - col_widths = [min(w, 80) for w in col_widths] - - header_line = "+".join("-" * (width + 2) for width in col_widths) - header_line = f"+{header_line}+" + # Create Rich table + table = Table(show_lines=separate_rows) + # Add headers if provided if headers: - print(header_line) - cprint(format_row(headers, col_widths), "white", attrs=["bold"]) + for header in headers: + table.add_column(header, style="bold white") + else: + # Add unnamed columns based on first row + for _ in range(len(rows[0]) if rows else 0): + table.add_column() - print(header_line) + # Add rows for row in rows: - print(format_row(row, col_widths)) - if separate_rows: - print(header_line) + table.add_row(*row) - if not separate_rows: - print(header_line) + # Print table + console = Console() + console.print(table) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index a40651551..639e5ee73 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -47,6 +47,8 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) +logger = logging.getLogger(__name__) + T = TypeVar("T") @@ -87,7 +89,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: try: return [convert_to_pydantic(item_type, item) for item in value] except Exception: - print(f"Error converting list {value} into {item_type}") + logger.error(f"Error converting list {value} into {item_type}") return value elif origin is dict: @@ -95,7 +97,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: try: return {k: convert_to_pydantic(val_type, v) for k, v in value.items()} except Exception: - print(f"Error converting dict {value} into {val_type}") + logger.error(f"Error converting dict {value} into {val_type}") return value try: @@ -111,9 +113,8 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: return convert_to_pydantic(union_type, value) except Exception: continue - cprint( + logger.warning( f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", - "yellow", ) return value @@ -152,7 +153,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - print(f"Removed handler {handler.__class__.__name__} from root logger") + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") def request(self, *args, **kwargs): if kwargs.get("stream"): diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 0c5b7c454..8e67333af 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging import warnings from typing import AsyncIterator, List, Optional, Union @@ -25,7 +26,12 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ) -from llama_stack.models.llama.datatypes import CoreModelId, SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.datatypes import ( + CoreModelId, + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, @@ -43,6 +49,8 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted, check_health +logger = logging.getLogger(__name__) + _MODEL_ALIASES = [ build_model_alias( "meta/llama3-8b-instruct", @@ -90,7 +98,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) - print(f"Initializing NVIDIAInferenceAdapter({config.url})...") + logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") if _is_nvidia_hosted(config): if not config.api_key: diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 0ec80e9dd..7d3f3f27e 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging from typing import Tuple import httpx from . import NVIDIAConfig +logger = logging.getLogger(__name__) + def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: return "integrate.api.nvidia.com" in config.url @@ -42,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None: RuntimeError: If the server is not running or ready """ if not _is_nvidia_hosted(config): - print("Checking NVIDIA NIM health...") + logger.info("Checking NVIDIA NIM health...") try: is_live, is_ready = await _get_health(config.url) if not is_live: