diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index b45447328..1726e5455 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -14,15 +14,19 @@ import httpx from dotenv import load_dotenv from pydantic import BaseModel -from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import RemoteProviderConfig from .agents import * # noqa: F403 +import logging + from .event_logger import EventLogger +log = logging.getLogger(__name__) + + load_dotenv() @@ -93,13 +97,12 @@ class AgentsClient(Agents): try: jdata = json.loads(data) if "error" in jdata: - cprint(data, "red") + log.error(data) continue yield AgentTurnResponseStreamChunk(**jdata) except Exception as e: - print(data) - print(f"Error with parsing or validation: {e}") + log.error(f"Error with parsing or validation: {e}") async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest): raise NotImplementedError("Non-streaming not implemented yet") @@ -125,7 +128,7 @@ async def _run_agent( ) for content in user_prompts: - cprint(f"User> {content}", color="white", attrs=["bold"]) + log.info(f"User> {content}") iterator = await api.create_agent_turn( AgentTurnCreateRequest( agent_id=create_response.agent_id, @@ -138,9 +141,9 @@ async def _run_agent( ) ) - async for event, log in EventLogger().log(iterator): - if log is not None: - log.print() + async for event, logger in EventLogger().log(iterator): + if logger is not None: + log.info(logger) async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"): diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 92e33b9fd..19b358a77 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -4,14 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of 
this source tree. +import logging from enum import Enum from typing import List import pkg_resources from pydantic import BaseModel -from termcolor import cprint - from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.datatypes import * # noqa: F403 @@ -22,6 +21,8 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR +log = logging.getLogger(__name__) + # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. SERVER_DEPENDENCIES = [ @@ -89,12 +90,12 @@ def get_provider_dependencies( def print_pip_install_help(providers: Dict[str, List[Provider]]): normal_deps, special_deps = get_provider_dependencies(providers) - print( + log.info( f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}" ) for special_dep in special_deps: - print(f"\tpip install {special_dep}") - print() + log.info(f"\tpip install {special_dep}") + log.info("") def build_image(build_config: BuildConfig, build_file_path: Path): @@ -133,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path): return_code = run_with_pty(args) if return_code != 0: - cprint( + log.error( f"Failed to build target {build_config.name} with return code {return_code}", - color="red", ) return return_code diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 09e277dad..a4d0f970b 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -3,12 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging import textwrap from typing import Any from llama_stack.distribution.datatypes import * # noqa: F403 -from termcolor import cprint from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, @@ -22,6 +22,8 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +logger = logging.getLogger(__name__) + def configure_single_provider( registry: Dict[str, ProviderSpec], provider: Provider @@ -50,7 +52,7 @@ def configure_api_providers( is_nux = len(config.providers) == 0 if is_nux: - print( + logger.info( textwrap.dedent( """ Llama Stack is composed of several APIs working together. For each API served by the Stack, @@ -76,18 +78,18 @@ def configure_api_providers( existing_providers = config.providers.get(api_str, []) if existing_providers: - cprint( + logger.info( f"Re-configuring existing providers for API `{api_str}`...", - "green", - attrs=["bold"], ) updated_providers = [] for p in existing_providers: - print(f"> Configuring provider `({p.provider_type})`") + logger.info(f"> Configuring provider `({p.provider_type})`") updated_providers.append( configure_single_provider(provider_registry[api], p) ) - print("") + logger.info("") else: # we are newly configuring this API plist = build_spec.providers.get(api_str, []) @@ -96,17 +98,17 @@ def configure_api_providers( if not plist: raise ValueError(f"No provider configured for API {api_str}?") - cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) + logger.info(f"Configuring API `{api_str}`...") updated_providers = [] for i, provider_type in enumerate(plist): if i >= 1: others = ", ".join(plist[i:]) - print( + logger.info( f"Not configuring other providers ({others}) interactively. 
Please edit the resulting YAML directly.\n" ) break - print(f"> Configuring provider `({provider_type})`") + logger.info(f"> Configuring provider `({provider_type})`") updated_providers.append( configure_single_provider( provider_registry[api], @@ -121,7 +123,7 @@ def configure_api_providers( ), ) ) - print("") + logger.info("") config.providers[api_str] = updated_providers @@ -182,7 +184,7 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi return StackRunConfig(**config_dict) if "routing_table" in config_dict: - print("Upgrading config...") + logger.info("Upgrading config...") config_dict = upgrade_from_routing_table(config_dict) config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index bbb1fff9d..27ef3046a 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -5,11 +5,14 @@ # the root directory of this source tree. 
 import json +import logging import threading from typing import Any, Dict from .utils.dynamic import instantiate_class_type +log = logging.getLogger(__name__) + _THREAD_LOCAL = threading.local() @@ -32,7 +35,7 @@ class NeedsRequestProviderData: provider_data = validator(**val) return provider_data except Exception as e: - print("Error parsing provider data", e) + log.error("Error parsing provider data: %s", e) def set_request_provider_data(headers: Dict[str, str]): @@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]): try: val = json.loads(val) except json.JSONDecodeError: - print("Provider data not encoded as a JSON object!", val) + log.error("Provider data not encoded as a JSON object! %s", val) return _THREAD_LOCAL.provider_data_header_value = val diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 4c74b0d1f..aa18de15b 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -8,11 +8,12 @@ import inspect from typing import Any, Dict, List, Set -from termcolor import cprint from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import logging + from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -33,6 +34,8 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +log = logging.getLogger(__name__) + class InvalidProviderError(Exception): pass @@ -115,11 +118,11 @@ async def resolve_impls( p = provider_registry[api][provider.provider_type] if p.deprecation_error: - cprint(p.deprecation_error, "red", attrs=["bold"]) + log.error(p.deprecation_error) raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: - cprint( 
+ log.warning( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", - "yellow", - attrs=["bold"], @@ -199,10 +202,10 @@ async def resolve_impls( ) ) - print(f"Resolved {len(sorted_providers)} providers") + log.info(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - print(f" {api_str} => {provider.provider_id}") - print("") + log.info(f" {api_str} => {provider.provider_id}") + log.info("") impls = {} inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} @@ -339,7 +342,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if not (proto_params <= obj_params): - print( + log.error( f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}" ) missing_methods.append((name, "signature_mismatch")) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f0d91f3a6..b8ff0e785 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -46,6 +46,10 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.providers.inline.meta_reference.telemetry.console import ( + ConsoleConfig, + ConsoleTelemetryImpl, +) from .endpoints import get_all_api_endpoints @@ -196,7 +200,6 @@ def handle_sigint(app, *args, **kwargs): async def lifespan(app: FastAPI): print("Starting up") yield - print("Shutting down") for impl in app.__llama_stack_impls__.values(): await impl.shutdown() @@ -214,6 +217,7 @@ async def maybe_await(value): async def sse_generator(event_gen): + await start_trace("sse_generator") try: event_gen = await event_gen async for item in event_gen: @@ -333,7 +337,7 @@ def main(): print("Run configuration:") print(yaml.dump(config.model_dump(), indent=2)) - app = FastAPI() + app = FastAPI(lifespan=lifespan) try: 
impls = asyncio.run(construct_stack(config)) @@ -342,6 +346,8 @@ def main(): if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) + else: + setup_logger(ConsoleTelemetryImpl(ConsoleConfig())) all_endpoints = get_all_api_endpoints() diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 9bd058400..75126c221 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging import os from pathlib import Path from typing import Any, Dict @@ -40,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api +log = logging.getLogger(__name__) + LLAMA_STACK_API_VERSION = "alpha" @@ -93,11 +96,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): method = getattr(impls[api], list_method) for obj in await method(): - print( + log.info( f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", ) - print("") + log.info("") class EnvVarError(Exception): diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index a01a1cf80..7b06e384d 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import errno +import logging import os import pty import select @@ -13,7 +14,7 @@ import subprocess import sys import termios -from termcolor import cprint +log = logging.getLogger(__name__) # run a command in a pseudo-terminal, with interrupt handling, @@ -29,7 +30,7 @@ def run_with_pty(command): def sigint_handler(signum, frame): nonlocal ctrl_c_pressed ctrl_c_pressed = True - cprint("\nCtrl-C detected. 
Aborting...", "white", attrs=["bold"]) + log.info("\nCtrl-C detected. Aborting...") try: # Set up the signal handler @@ -100,6 +101,6 @@ def run_command(command): process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, error = process.communicate() if process.returncode != 0: - print(f"Error: {error.decode('utf-8')}") + log.error(f"Error: {error.decode('utf-8')}") sys.exit(1) return output.decode("utf-8") diff --git a/llama_stack/distribution/utils/prompt_for_config.py b/llama_stack/distribution/utils/prompt_for_config.py index 54e9e9cc3..2eec655b1 100644 --- a/llama_stack/distribution/utils/prompt_for_config.py +++ b/llama_stack/distribution/utils/prompt_for_config.py @@ -6,6 +6,7 @@ import inspect import json +import logging from enum import Enum from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union @@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType from typing_extensions import Annotated +log = logging.getLogger(__name__) + def is_list_of_primitives(field_type): """Check if a field type is a List of primitive types.""" @@ -111,7 +114,7 @@ def prompt_for_discriminated_union( if discriminator_value in type_map: chosen_type = type_map[discriminator_value] - print(f"\nConfiguring {chosen_type.__name__}:") + log.info(f"\nConfiguring {chosen_type.__name__}:") if existing_value and ( getattr(existing_value, discriminator) != discriminator_value @@ -123,7 +126,7 @@ def prompt_for_discriminated_union( setattr(sub_config, discriminator, discriminator_value) return sub_config else: - print(f"Invalid {discriminator}. Please try again.") + log.error(f"Invalid {discriminator}. Please try again.") # This is somewhat elaborate, but does not purport to be comprehensive in any way. @@ -180,7 +183,7 @@ def prompt_for_config( config_data[field_name] = validated_value break except KeyError: - print( + log.error( f"Invalid choice. 
Please choose from: {', '.join(e.name for e in field_type)}" ) continue @@ -197,7 +200,7 @@ def prompt_for_config( config_data[field_name] = None continue nested_type = get_non_none_type(field_type) - print(f"Entering sub-configuration for {field_name}:") + log.info(f"Entering sub-configuration for {field_name}:") config_data[field_name] = prompt_for_config(nested_type, existing_value) elif is_optional(field_type) and is_discriminated_union( get_non_none_type(field_type) @@ -213,7 +216,7 @@ def prompt_for_config( existing_value, ) elif can_recurse(field_type): - print(f"\nEntering sub-configuration for {field_name}:") + log.info(f"\nEntering sub-configuration for {field_name}:") config_data[field_name] = prompt_for_config( field_type, existing_value, @@ -240,7 +243,7 @@ def prompt_for_config( config_data[field_name] = None break else: - print("This field is required. Please provide a value.") + log.error("This field is required. Please provide a value.") continue else: try: @@ -264,12 +267,12 @@ def prompt_for_config( value = [element_type(item) for item in value] except json.JSONDecodeError: - print( + log.error( 'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]' ) continue except ValueError as e: - print(f"{str(e)}") + log.error(f"{str(e)}") continue elif get_origin(field_type) is dict: @@ -281,7 +284,7 @@ def prompt_for_config( ) except json.JSONDecodeError: - print( + log.error( "Invalid JSON. Please enter a valid JSON-encoded dict." ) continue @@ -298,7 +301,7 @@ def prompt_for_config( value = field_type(user_input) except ValueError: - print( + log.error( f"Invalid input. 
Expected type: {getattr(field_type, '__name__', str(field_type))}" ) continue @@ -311,6 +314,6 @@ def prompt_for_config( config_data[field_name] = validated_value break except ValueError as e: - print(f"Validation error: {str(e)}") + log.error(f"Validation error: {str(e)}") return config_type(**config_data) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 0c15b1b5e..6d7fb95c1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -6,6 +6,7 @@ import asyncio import copy +import logging import os import re import secrets @@ -19,7 +20,6 @@ from urllib.parse import urlparse import httpx -from termcolor import cprint from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -43,6 +43,8 @@ from .tools.builtin import ( ) from .tools.safety import SafeTool +log = logging.getLogger(__name__) + def make_random_string(length: int = 8): return "".join( @@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin): stop_reason=StopReason.end_of_turn, ) ) - # print_dialog(messages) return messages async def create_session(self, name: str) -> str: @@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin): stream=request.stream, ): if isinstance(chunk, CompletionMessage): - cprint( + log.info( f"{chunk.role.capitalize()}: {chunk.content}", - "white", - attrs=["bold"], ) output_message = chunk continue @@ -407,7 +406,7 @@ class ChatAgent(ShieldRunnerMixin): msg_str = f"{str(msg)[:500]}......{str(msg)[-500:]}" else: msg_str = str(msg) - cprint(f"{msg_str}", color=color) + log.info(f"{msg_str}") step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -506,12 +505,12 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - cprint("Done with MAX iterations, exiting.") + log.info("Done with 
MAX iterations, exiting.") yield message break if stop_reason == StopReason.out_of_tokens: - cprint("Out of token budget, exiting.") + log.info("Out of token budget, exiting.") yield message break @@ -525,10 +524,10 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + attachments yield message else: - cprint(f"Partial message: {str(message)}", color="green") + log.info(f"Partial message: {str(message)}") input_messages = input_messages + [message] else: - cprint(f"{str(message)}", color="green") + log.info(f"{str(message)}") try: tool_call = message.tool_calls[0] @@ -740,9 +739,8 @@ class ChatAgent(ShieldRunnerMixin): for c in chunks[: memory.max_chunks]: tokens += c.token_count if tokens > memory.max_tokens_in_context: - cprint( + log.error( f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", - "red", ) break picked.append(f"id:{c.document_id}; content:{c.content}") @@ -786,7 +784,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa path = urlparse(uri).path basename = os.path.basename(path) filepath = f"{tempdir}/{make_random_string() + basename}" - print(f"Downloading {url} -> {filepath}") + log.info(f"Downloading {url} -> {filepath}") async with httpx.AsyncClient() as client: r = await client.get(uri) @@ -826,20 +824,3 @@ async def execute_tool_call_maybe( tool = tools_dict[name] result_messages = await tool.run(messages) return result_messages - - -def print_dialog(messages: List[Message]): - for i, m in enumerate(messages): - if m.role == Role.user.value: - color = "red" - elif m.role == Role.assistant.value: - color = "white" - elif m.role == Role.ipython.value: - color = "yellow" - elif m.role == Role.system.value: - color = "green" - else: - color = "white" - - s = str(m) - cprint(f"{i} ::: {s[:100]}...", color=color) diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py 
b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 2565f1994..d51e25a32 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json - +import logging import uuid from datetime import datetime @@ -15,6 +15,8 @@ from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStore +log = logging.getLogger(__name__) + class AgentSessionInfo(BaseModel): session_id: str @@ -78,7 +80,7 @@ class AgentPersistence: turn = Turn(**json.loads(value)) turns.append(turn) except Exception as e: - print(f"Error parsing turn: {e}") + log.error(f"Error parsing turn: {e}") continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index b668dc0d6..08e778439 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -10,8 +10,6 @@ from jinja2 import Template from llama_models.llama3.api import * # noqa: F403 -from termcolor import cprint # noqa: F401 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, @@ -36,7 +34,6 @@ async def generate_rag_query( query = await llm_rag_query_generator(config, messages, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") - # cprint(f"Generated query >>>: {query}", color="green") return query diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 77525e871..3eca94fc5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ 
b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -5,14 +5,16 @@ # the root directory of this source tree. import asyncio +import logging from typing import List from llama_models.llama3.api.datatypes import Message -from termcolor import cprint from llama_stack.apis.safety import * # noqa: F403 +log = logging.getLogger(__name__) + class SafetyException(Exception): # noqa: N818 def __init__(self, violation: SafetyViolation): @@ -51,7 +53,4 @@ class ShieldRunnerMixin: if violation.violation_level == ViolationLevel.ERROR: raise SafetyException(violation) elif violation.violation_level == ViolationLevel.WARN: - cprint( - f"[Warn]{identifier} raised a warning", - color="red", - ) + log.warning(f"[Warn]{identifier} raised a warning") diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py index a1e7d08f5..0bbf67ed8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
import json +import logging import re import tempfile @@ -12,7 +13,6 @@ from abc import abstractmethod from typing import List, Optional import requests -from termcolor import cprint from .ipython_tool.code_execution import ( CodeExecutionContext, @@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403 from .base import BaseTool +log = logging.getLogger(__name__) + + def interpret_content_as_attachment(content: str) -> Optional[Attachment]: match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) if match: @@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool): if res_out != "": pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) if out_type == "stderr": - cprint(f"ipython tool error: ↓\n{res_out}", color="red") + log.error(f"ipython tool error: ↓\n{res_out}") message = ToolResponseMessage( call_id=tool_call.call_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py index 3aba2ef21..7fec08cf2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py @@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes import base64 import io import json as _json +import logging import matplotlib from matplotlib.backend_bases import FigureManagerBase @@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase # Import necessary components from Matplotlib from matplotlib.backends.backend_agg import FigureCanvasAgg +log = logging.getLogger(__name__) + class CustomFigureCanvas(FigureCanvasAgg): def show(self): @@ -80,7 +83,7 @@ def show(): ) req_con.send_bytes(_json_dump.encode("utf-8")) resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) - print(resp) + log.info(resp) FigureCanvas = 
CustomFigureCanvas diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 577f5184b..080e33be0 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -8,6 +8,7 @@ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. import json +import logging import math import os import sys @@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import ( ) from llama_models.sku_list import resolve_model from pydantic import BaseModel -from termcolor import cprint from llama_stack.apis.inference import * # noqa: F403 @@ -50,6 +50,8 @@ from .config import ( MetaReferenceQuantizedInferenceConfig, ) +log = logging.getLogger(__name__) + def model_checkpoint_dir(model) -> str: checkpoint_dir = Path(model_local_dir(model.descriptor())) @@ -185,7 +187,7 @@ class Llama: model = Transformer(model_args) model.load_state_dict(state_dict, strict=False) - print(f"Loaded in {time.time() - start_time:.2f} seconds") + log.info(f"Loaded in {time.time() - start_time:.2f} seconds") return Llama(model, tokenizer, model_args, llama_model) def __init__( @@ -221,7 +223,7 @@ class Llama: self.formatter.vision_token if t == 128256 else t for t in model_input.tokens ] - cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red") + log.info("Input to model -> " + self.tokenizer.decode(input_tokens)) prompt_tokens = [model_input.tokens] bsz = 1 @@ -231,9 +233,7 @@ class Llama: max_prompt_len = max(len(t) for t in prompt_tokens) if max_prompt_len >= params.max_seq_len: - cprint( - f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red" - ) + log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}") return total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) diff --git 
a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e6bcd6730..07fd4af44 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import asyncio +import logging from typing import AsyncGenerator, List @@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator +log = logging.getLogger(__name__) # there's a single model parallel process running serving the model. for now, # we don't support multiple concurrent requests to this process. SEMAPHORE = asyncio.Semaphore(1) @@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP # verify that the checkpoint actually is for this model lol async def initialize(self) -> None: - print(f"Loading model `{self.model.descriptor()}`") + log.info(f"Loading model `{self.model.descriptor()}`") if self.config.create_distributed_process_group: self.generator = LlamaModelParallelGenerator(self.config) self.generator.start() diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 62eeefaac..076e39729 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -11,6 +11,7 @@ # the root directory of this source tree. 
import json +import logging import multiprocessing import os import tempfile @@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest from .generation import TokenResult +log = logging.getLogger(__name__) + class ProcessingMessageName(str, Enum): ready_request = "ready_request" @@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str): group=get_model_parallel_group(), ) if isinstance(updates[0], CancelSentinel): - print("quitting generation loop because request was cancelled") + log.info( + "quitting generation loop because request was cancelled" + ) break if mp_rank_0(): send_obj(EndSentinel()) except Exception as e: - print(f"[debug] got exception {e}") - import traceback + log.exception("exception in generation loop") - traceback.print_exc() if mp_rank_0(): send_obj(ExceptionResponse(error=str(e))) @@ -252,7 +255,7 @@ def worker_process_entrypoint( except StopIteration: break - print("[debug] worker process done") + log.info("[debug] worker process done") def launch_dist_group( @@ -313,7 +316,7 @@ def start_model_parallel_process( request_socket.send(encode_msg(ReadyRequest())) response = request_socket.recv() - print("Loaded model...") + log.info("Loaded model...") return request_socket, process @@ -361,7 +364,7 @@ class ModelParallelProcessGroup: break if isinstance(obj, ExceptionResponse): - print(f"[debug] got exception {obj.error}") + log.error(f"[debug] got exception {obj.error}") raise Exception(obj.error) if isinstance(obj, TaskResponse): diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py index 98cf2a9a1..92c447707 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py @@ -8,14 +8,20 @@ # This software may be used and distributed in accordance with the terms of 
the Llama 3 Community License Agreement. import collections + +import logging from typing import Optional, Type +log = logging.getLogger(__name__) + try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 - print("Using efficient FP8 operators in FBGEMM.") + log.info("Using efficient FP8 operators in FBGEMM.") except ImportError: - print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.") + log.error( + "No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt." + ) raise import torch diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 3eaac1e71..80d47b054 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -7,6 +7,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
+import logging import os from typing import Any, Dict, List, Optional @@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.sku_list import resolve_model -from termcolor import cprint from torch import nn, Tensor from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear @@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType from ..config import MetaReferenceQuantizedInferenceConfig +log = logging.getLogger(__name__) + def swiglu_wrapper( self, @@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model( # Move weights to GPU with quantization if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value: - cprint("Loading fp8 scales...", "yellow") + log.info("Loading fp8 scales...") fp8_scales_path = os.path.join( checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt" ) @@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model( fp8_activation_scale_ub, ) else: - cprint("Quantizing fp8 weights from bf16...", "yellow") + log.info("Quantizing fp8 weights from bf16...") for block in model.layers: if isinstance(block, TransformerBlock): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py index 891a06296..b282d976f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py @@ -8,6 +8,7 @@ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
import json +import logging import os import shutil import sys @@ -32,6 +33,8 @@ from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impl quantize_fp8, ) +log = logging.getLogger(__name__) + def main( ckpt_dir: str, @@ -102,7 +105,7 @@ def main( else: torch.set_default_tensor_type(torch.cuda.HalfTensor) - print(ckpt_path) + log.info(ckpt_path) assert ( quantized_ckpt_dir is not None ), "QUantized checkpoint directory should not be None" diff --git a/llama_stack/providers/inline/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py index c639c6798..34d5bc08e 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/config.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/config.py @@ -4,10 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum + from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +class LogFormat(Enum): + TEXT = "text" + JSON = "json" + + @json_schema_type -class ConsoleConfig(BaseModel): ... +class ConsoleConfig(BaseModel): + log_format: LogFormat = LogFormat.JSON diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py index b56c704a6..d8ef49481 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py @@ -4,8 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import json from typing import Optional +from .config import LogFormat + from llama_stack.apis.telemetry import * # noqa: F403 from .config import ConsoleConfig @@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry): span_name = ".".join(names) if names else None - formatted = format_event(event, span_name) + if self.config.log_format == LogFormat.JSON: + formatted = format_event_json(event, span_name) + else: + formatted = format_event_text(event, span_name) + if formatted: print(formatted) @@ -69,7 +76,7 @@ SEVERITY_COLORS = { } -def format_event(event: Event, span_name: str) -> Optional[str]: +def format_event_text(event: Event, span_name: str) -> Optional[str]: timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3] span = "" if span_name: @@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]: return None return f"Unknown event type: {event}" + + +def format_event_json(event: Event, span_name: str) -> Optional[str]: + base_data = { + "timestamp": event.timestamp.isoformat(), + "trace_id": event.trace_id, + "span_id": event.span_id, + "span_name": span_name, + } + + if isinstance(event, UnstructuredLogEvent): + base_data.update( + {"type": "log", "severity": event.severity.name, "message": event.message} + ) + return json.dumps(base_data) + + elif isinstance(event, StructuredLogEvent): + return None + + return json.dumps({"error": f"Unknown event type: {event}"}) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index c477c685c..54a4d0b18 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging from typing import Any, Dict, List from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message -from termcolor import cprint from .config import CodeScannerConfig from llama_stack.apis.safety import * # noqa: F403 - +log = logging.getLogger(__name__) ALLOWED_CODE_SCANNER_MODEL_IDS = [ "CodeScanner", "CodeShield", @@ -49,7 +49,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): from codeshield.cs import CodeShield text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages]) - cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta") + log.info(f"Running CodeScannerShield on {text[:50]}") result = await CodeShield.scan_code(text) violation = None diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 9f3d78374..e2deb3df7 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
+import logging from typing import Any, Dict, List import torch -from termcolor import cprint from transformers import AutoModelForSequenceClassification, AutoTokenizer @@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate from .config import PromptGuardConfig, PromptGuardType +log = logging.getLogger(__name__) PROMPT_GUARD_MODEL = "Prompt-Guard-86M" @@ -93,9 +94,8 @@ class PromptGuardShield: probabilities = torch.softmax(logits / self.temperature, dim=-1) score_embedded = probabilities[0, 1].item() score_malicious = probabilities[0, 2].item() - cprint( + log.info( f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", - color="magenta", ) violation = None diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f53ed4e14..56287fd65 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging from typing import AsyncGenerator import httpx @@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( request_has_media, ) +log = logging.getLogger(__name__) model_aliases = [ build_model_alias( @@ -105,7 +107,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return AsyncClient(host=self.url) async def initialize(self) -> None: - print(f"checking connectivity to Ollama at `{self.url}`...") + log.info(f"checking connectivity to Ollama at `{self.url}`...") try: await self.client.ps() except httpx.ConnectError as e: diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 92492e3da..d57fbdc17 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -34,7 +34,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -logger = logging.getLogger(__name__) +log = logging.getLogger(__name__) class _HfAdapter(Inference, ModelsProtocolPrivate): @@ -264,7 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: - print(f"Initializing TGI client with url={config.url}") + log.info(f"Initializing TGI client with url={config.url}") self.client = AsyncInferenceClient(model=config.url, token=config.api_token) endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 3c877639c..0f4034478 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -3,6 +3,8 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of 
this source tree. + +import logging from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat @@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig +log = logging.getLogger(__name__) + + def build_model_aliases(): return [ build_model_alias( @@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self.client = None async def initialize(self) -> None: - print(f"Initializing VLLM client with base_url={self.config.url}") + log.info(f"Initializing VLLM client with base_url={self.config.url}") self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) async def shutdown(self) -> None: diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 3ccd6a534..20185aade 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
import json +import logging from typing import List from urllib.parse import urlparse @@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, ) +log = logging.getLogger(__name__) + class ChromaIndex(EmbeddingIndex): def __init__(self, client: chromadb.AsyncHttpClient, collection): @@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex): doc = json.loads(doc) chunk = Chunk(**doc) except Exception: - import traceback - - traceback.print_exc() - print(f"Failed to parse document: {doc}") + log.exception(f"Failed to parse document: {doc}") continue chunks.append(chunk) @@ -73,7 +73,7 @@ class ChromaIndex(EmbeddingIndex): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__(self, url: str) -> None: - print(f"Initializing ChromaMemoryAdapter with url: {url}") + log.info(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") parsed = urlparse(url) @@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def initialize(self) -> None: try: - print(f"Connecting to Chroma server at: {self.host}:{self.port}") + log.info(f"Connecting to Chroma server at: {self.host}:{self.port}") self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port) except Exception as e: - import traceback - - traceback.print_exc() + log.exception("Could not connect to Chroma server") raise RuntimeError("Could not connect to Chroma server") from e async def shutdown(self) -> None: @@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): data = json.loads(collection.metadata["bank"]) bank = parse_obj_as(VectorMemoryBank, data) except Exception: - import traceback - - traceback.print_exc() - print(f"Failed to parse bank: {collection.metadata}") + log.exception(f"Failed to parse bank: {collection.metadata}") continue index = BankWithIndex( diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py 
b/llama_stack/providers/remote/memory/pgvector/pgvector.py index bd27509d6..d77de7b41 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging from typing import List, Tuple import psycopg2 @@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorConfig +log = logging.getLogger(__name__) + def check_extension_version(cur): cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") @@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.cache = {} async def initialize(self) -> None: - print(f"Initializing PGVector memory adapter with config: {self.config}") + log.info(f"Initializing PGVector memory adapter for host: {self.config.host}") try: self.conn = psycopg2.connect( host=self.config.host, @@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): version = check_extension_version(self.cursor) if version: - print(f"Vector extension version: {version}") + log.info(f"Vector extension version: {version}") else: raise RuntimeError("Vector extension is not installed.") @@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): """ ) except Exception as e: - import traceback - - traceback.print_exc() + log.exception("Could not connect to PGVector database server") raise RuntimeError("Could not connect to PGVector database server") from e async def shutdown(self) -> None: diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index 27923a7c5..be370eec9 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -4,7 +4,7 @@ # This source code is licensed under
the terms described in the LICENSE file in # the root directory of this source tree. -import traceback +import logging import uuid from typing import Any, Dict, List @@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, ) +log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" @@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex): try: chunk = Chunk(**point.payload["chunk_content"]) except Exception: - traceback.print_exc() + log.exception("Failed to parse chunk") continue chunks.append(chunk) diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 2844402b5..f8fba5c0b 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json +import logging from typing import Any, Dict, List, Optional @@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import WeaviateConfig, WeaviateRequestProviderData +log = logging.getLogger(__name__) + class WeaviateIndex(EmbeddingIndex): def __init__(self, client: weaviate.Client, collection_name: str): @@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex): chunk_dict = json.loads(chunk_json) chunk = Chunk(**chunk_dict) except Exception: - import traceback - - traceback.print_exc() - print(f"Failed to parse document: {chunk_json}") + log.exception(f"Failed to parse document: {chunk_json}") continue chunks.append(chunk) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 6e4d0752e..ca06e1b1f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -7,14 +7,13 @@ import base64 import io import json +import 
logging from typing import Tuple import httpx from llama_models.llama3.api.chat_format import ChatFormat from PIL import Image as PIL_Image -from termcolor import cprint - from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_models.datatypes import ModelFamily @@ -29,6 +28,8 @@ from llama_models.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models +log = logging.getLogger(__name__) + def content_has_media(content: InterleavedTextMedia): def _has_media_content(c): @@ -175,13 +176,13 @@ def chat_completion_request_to_messages( """ model = resolve_model(llama_model) if model is None: - cprint(f"Could not resolve model {llama_model}", color="red") + log.error(f"Could not resolve model {llama_model}") return request.messages allowed_models = supported_inference_models() descriptors = [m.descriptor() for m in allowed_models] if model.descriptor() not in descriptors: - cprint(f"Unsupported inference model? {model.descriptor()}", color="red") + log.error(f"Unsupported inference model? {model.descriptor()}") return request.messages if model.model_family == ModelFamily.llama3_1 or ( diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index 23ceb58e4..20428f285 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging from datetime import datetime from typing import List, Optional @@ -13,6 +14,8 @@ from psycopg2.extras import DictCursor from ..api import KVStore from ..config import PostgresKVStoreConfig +log = logging.getLogger(__name__) + class PostgresKVStoreImpl(KVStore): def __init__(self, config: PostgresKVStoreConfig): @@ -43,9 +46,8 @@ class PostgresKVStoreImpl(KVStore): """ ) except Exception as e: - import traceback - traceback.print_exc() + log.exception("Could not connect to PostgreSQL database server") raise RuntimeError("Could not connect to PostgreSQL database server") from e def _namespaced_key(self, key: str) -> str: diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 2bbf6cdd2..48cb8a99d 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import base64 import io +import logging import re from abc import ABC, abstractmethod from dataclasses import dataclass @@ -16,13 +17,14 @@ import httpx import numpy as np from numpy.typing import NDArray from pypdf import PdfReader -from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.memory import * # noqa: F403 +log = logging.getLogger(__name__) + ALL_MINILM_L6_V2_DIMENSION = 384 EMBEDDING_MODELS = {} @@ -35,7 +37,7 @@ def get_embedding_model(model: str) -> "SentenceTransformer": if loaded_model is not None: return loaded_model - print(f"Loading sentence transformer for {model}...") + log.info(f"Loading sentence transformer for {model}...") from sentence_transformers import SentenceTransformer loaded_model = SentenceTransformer(model) @@ -92,7 +94,7 @@ def content_from_data(data_url: str) -> str: return "\n".join([page.extract_text() for page in pdf_reader.pages]) else: - cprint("Could not extract 
content from data_url properly.", color="red") + log.error("Could not extract content from data_url properly.") return "" diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 207064904..3383f7a7a 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -17,6 +17,8 @@ from typing import Any, Callable, Dict, List from llama_stack.apis.telemetry import * # noqa: F403 +log = logging.getLogger(__name__) + def generate_short_uuid(len: int = 12): full_uuid = uuid.uuid4() @@ -40,7 +42,7 @@ class BackgroundLogger: try: self.log_queue.put_nowait(event) except queue.Full: - print("Log queue is full, dropping event") + log.error("Log queue is full, dropping event") def _process_logs(self): while True: @@ -125,7 +127,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None): global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: - print("No Telemetry implementation set. Skipping trace initialization...") + log.info("No Telemetry implementation set. Skipping trace initialization...") return trace_id = generate_short_uuid()