use logging instead of prints (#499)

# What does this PR do?

This PR moves all print statements to use logging. Things changed:
- Had to add `await start_trace("sse_generator")` to server.py to
actually get tracing working; otherwise no logs were visible.
- If no telemetry provider is configured in run.yaml, logs are written
to stdout.
- By default, logs are emitted as JSON, but we expose an option to
configure human-readable output instead.
This commit is contained in:
Dinesh Yeduguru 2024-11-21 11:32:53 -08:00 committed by GitHub
parent 4e1105e563
commit 6395dadc2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 234 additions and 163 deletions

View file

@ -4,14 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from enum import Enum
from typing import List
import pkg_resources
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
@ -22,6 +21,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
@ -89,12 +90,12 @@ def get_provider_dependencies(
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
log.info(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
print()
log.info(f"\tpip install {special_dep}")
log.info()
def build_image(build_config: BuildConfig, build_file_path: Path):
@ -133,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
return_code = run_with_pty(args)
if return_code != 0:
cprint(
log.error(
f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
)
return return_code

View file

@ -3,12 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
@ -22,6 +22,8 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
logger = logging.getLogger(__name__)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
@ -50,7 +52,7 @@ def configure_api_providers(
is_nux = len(config.providers) == 0
if is_nux:
print(
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
@ -76,18 +78,18 @@ def configure_api_providers(
existing_providers = config.providers.get(api_str, [])
if existing_providers:
cprint(
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
updated_providers = []
for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`")
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
print("")
logger.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
@ -96,17 +98,17 @@ def configure_api_providers(
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:
others = ", ".join(plist[i:])
print(
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
print(f"> Configuring provider `({provider_type})`")
logger.info(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
@ -121,7 +123,7 @@ def configure_api_providers(
),
)
)
print("")
logger.info("")
config.providers[api_str] = updated_providers
@ -182,7 +184,7 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
return StackRunConfig(**config_dict)
if "routing_table" in config_dict:
print("Upgrading config...")
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
import json
import logging
import threading
from typing import Any, Dict
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
@ -32,7 +35,7 @@ class NeedsRequestProviderData:
provider_data = validator(**val)
return provider_data
except Exception as e:
print("Error parsing provider data", e)
log.error("Error parsing provider data", e)
def set_request_provider_data(headers: Dict[str, str]):
@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
try:
val = json.loads(val)
except json.JSONDecodeError:
print("Provider data not encoded as a JSON object!", val)
log.error("Provider data not encoded as a JSON object!", val)
return
_THREAD_LOCAL.provider_data_header_value = val

View file

@ -8,11 +8,12 @@ import inspect
from typing import Any, Dict, List, Set
from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import logging
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -33,6 +34,8 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
class InvalidProviderError(Exception):
pass
@ -115,11 +118,11 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"])
log.error(p.deprecation_error, "red", attrs=["bold"])
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
cprint(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"yellow",
attrs=["bold"],
@ -199,10 +202,10 @@ async def resolve_impls(
)
)
print(f"Resolved {len(sorted_providers)} providers")
log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
print(f" {api_str} => {provider.provider_id}")
print("")
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@ -339,7 +342,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
print(
log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))

View file

@ -46,6 +46,10 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.meta_reference.telemetry.console import (
ConsoleConfig,
ConsoleTelemetryImpl,
)
from .endpoints import get_all_api_endpoints
@ -196,7 +200,6 @@ def handle_sigint(app, *args, **kwargs):
async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -214,6 +217,7 @@ async def maybe_await(value):
async def sse_generator(event_gen):
await start_trace("sse_generator")
try:
event_gen = await event_gen
async for item in event_gen:
@ -333,7 +337,7 @@ def main():
print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2))
app = FastAPI()
app = FastAPI(lifespan=lifespan)
try:
impls = asyncio.run(construct_stack(config))
@ -342,6 +346,8 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
all_endpoints = get_all_api_endpoints()

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from pathlib import Path
from typing import Any, Dict
@ -40,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -93,11 +96,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
method = getattr(impls[api], list_method)
for obj in await method():
print(
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
print("")
log.info("")
class EnvVarError(Exception):

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import errno
import logging
import os
import pty
import select
@ -13,7 +14,7 @@ import subprocess
import sys
import termios
from termcolor import cprint
log = logging.getLogger(__name__)
# run a command in a pseudo-terminal, with interrupt handling,
@ -29,7 +30,7 @@ def run_with_pty(command):
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
log.info("\nCtrl-C detected. Aborting...")
try:
# Set up the signal handler
@ -100,6 +101,6 @@ def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}")
log.error(f"Error: {error.decode('utf-8')}")
sys.exit(1)
return output.decode("utf-8")

View file

@ -6,6 +6,7 @@
import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
print(f"Invalid {discriminator}. Please try again.")
log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
@ -180,7 +183,7 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except KeyError:
print(
log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
@ -197,7 +200,7 @@ def prompt_for_config(
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
@ -213,7 +216,7 @@ def prompt_for_config(
existing_value,
)
elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:")
log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
@ -240,7 +243,7 @@ def prompt_for_config(
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
log.error("This field is required. Please provide a value.")
continue
else:
try:
@ -264,12 +267,12 @@ def prompt_for_config(
value = [element_type(item) for item in value]
except json.JSONDecodeError:
print(
log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
)
continue
except ValueError as e:
print(f"{str(e)}")
log.error(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
@ -281,7 +284,7 @@ def prompt_for_config(
)
except json.JSONDecodeError:
print(
log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
@ -298,7 +301,7 @@ def prompt_for_config(
value = field_type(user_input)
except ValueError:
print(
log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
continue
@ -311,6 +314,6 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except ValueError as e:
print(f"Validation error: {str(e)}")
log.error(f"Validation error: {str(e)}")
return config_type(**config_data)