use logging instead of prints (#499)

# What does this PR do?

This PR replaces all `print` statements with the standard `logging` module. Changes:
- Added `await start_trace("sse_generator")` to server.py so that tracing actually works; without it, no logs were emitted for SSE requests.
- If no telemetry provider is configured in run.yaml, logs are written to stdout via the console telemetry implementation.
- Logs are emitted as JSON by default, with a config option to switch to a human-readable format.
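
For reviewers, here is a minimal, self-contained sketch of the pattern applied across the codebase (a module-level logger instead of `print`/`cprint`), together with an illustrative way to emit JSON or human-readable lines. The `JsonFormatter`/`setup_logging` helpers below are hypothetical and not part of this diff, which routes formatting through the console telemetry provider instead:

```python
import json
import logging

# Pattern used throughout this PR: one module-level logger per file.
log = logging.getLogger(__name__)


class JsonFormatter(logging.Formatter):
    """Illustrative JSON formatter (hypothetical; the PR formats JSON in the
    console telemetry provider, not via a logging.Formatter subclass)."""

    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {
                "timestamp": self.formatTime(record),
                "severity": record.levelname,
                "logger": record.name,
                "message": record.getMessage(),
            }
        )


def setup_logging(json_output: bool = True) -> None:
    # Assumption: a single stdout handler; switch between JSON and plain text.
    handler = logging.StreamHandler()
    if json_output:
        handler.setFormatter(JsonFormatter())
    else:
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
        )
    logging.basicConfig(level=logging.INFO, handlers=[handler], force=True)


setup_logging(json_output=True)
log.info("Resolved 5 providers")  # replaces: print("Resolved 5 providers")
```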
Dinesh Yeduguru committed via GitHub on 2024-11-21 11:32:53 -08:00
commit 6395dadc2b (parent 4e1105e563)
36 changed files with 234 additions and 163 deletions


@@ -14,15 +14,19 @@ import httpx
 from dotenv import load_dotenv
 from pydantic import BaseModel
-from termcolor import cprint
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 from .agents import * # noqa: F403
+import logging
 from .event_logger import EventLogger
+log = logging.getLogger(__name__)
 load_dotenv()
@@ -93,13 +97,12 @@ class AgentsClient(Agents):
 try:
 jdata = json.loads(data)
 if "error" in jdata:
-cprint(data, "red")
+log.error(data)
 continue
 yield AgentTurnResponseStreamChunk(**jdata)
 except Exception as e:
-print(data)
-print(f"Error with parsing or validation: {e}")
+log.error(f"Error with parsing or validation: {e}")
 async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
 raise NotImplementedError("Non-streaming not implemented yet")
@@ -125,7 +128,7 @@ async def _run_agent(
 )
 for content in user_prompts:
-cprint(f"User> {content}", color="white", attrs=["bold"])
+log.info(f"User> {content}", color="white", attrs=["bold"])
 iterator = await api.create_agent_turn(
 AgentTurnCreateRequest(
 agent_id=create_response.agent_id,
@@ -138,9 +141,9 @@ async def _run_agent(
 )
 )
-async for event, log in EventLogger().log(iterator):
-if log is not None:
-log.print()
+async for event, logger in EventLogger().log(iterator):
+if logger is not None:
+log.info(logger)
 async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):


@ -4,14 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from enum import Enum from enum import Enum
from typing import List from typing import List
import pkg_resources import pkg_resources
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
@ -22,6 +21,8 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server. # These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script. # `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [ SERVER_DEPENDENCIES = [
@ -89,12 +90,12 @@ def get_provider_dependencies(
def print_pip_install_help(providers: Dict[str, List[Provider]]): def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers) normal_deps, special_deps = get_provider_dependencies(providers)
print( log.info(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}" f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
) )
for special_dep in special_deps: for special_dep in special_deps:
print(f"\tpip install {special_dep}") log.info(f"\tpip install {special_dep}")
print() log.info()
def build_image(build_config: BuildConfig, build_file_path: Path): def build_image(build_config: BuildConfig, build_file_path: Path):
@ -133,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
return_code = run_with_pty(args) return_code = run_with_pty(args)
if return_code != 0: if return_code != 0:
cprint( log.error(
f"Failed to build target {build_config.name} with return code {return_code}", f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
) )
return return_code return return_code


@ -3,12 +3,12 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import textwrap import textwrap
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
@ -22,6 +22,8 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403
logger = logging.getLogger(__name__)
def configure_single_provider( def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider registry: Dict[str, ProviderSpec], provider: Provider
@ -50,7 +52,7 @@ def configure_api_providers(
is_nux = len(config.providers) == 0 is_nux = len(config.providers) == 0
if is_nux: if is_nux:
print( logger.info(
textwrap.dedent( textwrap.dedent(
""" """
Llama Stack is composed of several APIs working together. For each API served by the Stack, Llama Stack is composed of several APIs working together. For each API served by the Stack,
@ -76,18 +78,18 @@ def configure_api_providers(
existing_providers = config.providers.get(api_str, []) existing_providers = config.providers.get(api_str, [])
if existing_providers: if existing_providers:
cprint( logger.info(
f"Re-configuring existing providers for API `{api_str}`...", f"Re-configuring existing providers for API `{api_str}`...",
"green", "green",
attrs=["bold"], attrs=["bold"],
) )
updated_providers = [] updated_providers = []
for p in existing_providers: for p in existing_providers:
print(f"> Configuring provider `({p.provider_type})`") logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append( updated_providers.append(
configure_single_provider(provider_registry[api], p) configure_single_provider(provider_registry[api], p)
) )
print("") logger.info("")
else: else:
# we are newly configuring this API # we are newly configuring this API
plist = build_spec.providers.get(api_str, []) plist = build_spec.providers.get(api_str, [])
@ -96,17 +98,17 @@ def configure_api_providers(
if not plist: if not plist:
raise ValueError(f"No provider configured for API {api_str}?") raise ValueError(f"No provider configured for API {api_str}?")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = [] updated_providers = []
for i, provider_type in enumerate(plist): for i, provider_type in enumerate(plist):
if i >= 1: if i >= 1:
others = ", ".join(plist[i:]) others = ", ".join(plist[i:])
print( logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n" f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
) )
break break
print(f"> Configuring provider `({provider_type})`") logger.info(f"> Configuring provider `({provider_type})`")
updated_providers.append( updated_providers.append(
configure_single_provider( configure_single_provider(
provider_registry[api], provider_registry[api],
@ -121,7 +123,7 @@ def configure_api_providers(
), ),
) )
) )
print("") logger.info("")
config.providers[api_str] = updated_providers config.providers[api_str] = updated_providers
@ -182,7 +184,7 @@ def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfi
return StackRunConfig(**config_dict) return StackRunConfig(**config_dict)
if "routing_table" in config_dict: if "routing_table" in config_dict:
print("Upgrading config...") logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict) config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION


@@ -5,11 +5,14 @@
 # the root directory of this source tree.
 import json
+import logging
 import threading
 from typing import Any, Dict
 from .utils.dynamic import instantiate_class_type
+log = logging.getLogger(__name__)
 _THREAD_LOCAL = threading.local()
@@ -32,7 +35,7 @@ class NeedsRequestProviderData:
 provider_data = validator(**val)
 return provider_data
 except Exception as e:
-print("Error parsing provider data", e)
+log.error("Error parsing provider data", e)
 def set_request_provider_data(headers: Dict[str, str]):
@@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
 try:
 val = json.loads(val)
 except json.JSONDecodeError:
-print("Provider data not encoded as a JSON object!", val)
+log.error("Provider data not encoded as a JSON object!", val)
 return
 _THREAD_LOCAL.provider_data_header_value = val


@ -8,11 +8,12 @@ import inspect
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set
from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import logging
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -33,6 +34,8 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
class InvalidProviderError(Exception): class InvalidProviderError(Exception):
pass pass
@ -115,11 +118,11 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type] p = provider_registry[api][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"]) log.error(p.deprecation_error, "red", attrs=["bold"])
raise InvalidProviderError(p.deprecation_error) raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning: elif p.deprecation_warning:
cprint( log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"yellow", "yellow",
attrs=["bold"], attrs=["bold"],
@ -199,10 +202,10 @@ async def resolve_impls(
) )
) )
print(f"Resolved {len(sorted_providers)} providers") log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
print(f" {api_str} => {provider.provider_id}") log.info(f" {api_str} => {provider.provider_id}")
print("") log.info("")
impls = {} impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@ -339,7 +342,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters) obj_params = set(obj_sig.parameters)
obj_params.discard("self") obj_params.discard("self")
if not (proto_params <= obj_params): if not (proto_params <= obj_params):
print( log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}" f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
) )
missing_methods.append((name, "signature_mismatch")) missing_methods.append((name, "signature_mismatch"))


@@ -46,6 +46,10 @@ from llama_stack.distribution.stack import (
 replace_env_vars,
 validate_env_pair,
 )
+from llama_stack.providers.inline.meta_reference.telemetry.console import (
+    ConsoleConfig,
+    ConsoleTelemetryImpl,
+)
 from .endpoints import get_all_api_endpoints
@@ -196,7 +200,6 @@ def handle_sigint(app, *args, **kwargs):
 async def lifespan(app: FastAPI):
 print("Starting up")
 yield
 print("Shutting down")
 for impl in app.__llama_stack_impls__.values():
 await impl.shutdown()
@@ -214,6 +217,7 @@ async def maybe_await(value):
 async def sse_generator(event_gen):
+await start_trace("sse_generator")
 try:
 event_gen = await event_gen
 async for item in event_gen:
@@ -333,7 +337,7 @@ def main():
 print("Run configuration:")
 print(yaml.dump(config.model_dump(), indent=2))
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 try:
 impls = asyncio.run(construct_stack(config))
@@ -342,6 +346,8 @@ def main():
 if Api.telemetry in impls:
 setup_logger(impls[Api.telemetry])
+else:
+setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
 all_endpoints = get_all_api_endpoints()


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
@ -40,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha" LLAMA_STACK_API_VERSION = "alpha"
@ -93,11 +96,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
method = getattr(impls[api], list_method) method = getattr(impls[api], list_method)
for obj in await method(): for obj in await method():
print( log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
) )
print("") log.info("")
class EnvVarError(Exception): class EnvVarError(Exception):


@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import errno import errno
import logging
import os import os
import pty import pty
import select import select
@ -13,7 +14,7 @@ import subprocess
import sys import sys
import termios import termios
from termcolor import cprint log = logging.getLogger(__name__)
# run a command in a pseudo-terminal, with interrupt handling, # run a command in a pseudo-terminal, with interrupt handling,
@ -29,7 +30,7 @@ def run_with_pty(command):
def sigint_handler(signum, frame): def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed nonlocal ctrl_c_pressed
ctrl_c_pressed = True ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"]) log.info("\nCtrl-C detected. Aborting...")
try: try:
# Set up the signal handler # Set up the signal handler
@ -100,6 +101,6 @@ def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate() output, error = process.communicate()
if process.returncode != 0: if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}") log.error(f"Error: {error.decode('utf-8')}")
sys.exit(1) sys.exit(1)
return output.decode("utf-8") return output.decode("utf-8")


@ -6,6 +6,7 @@
import inspect import inspect
import json import json
import logging
from enum import Enum from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated from typing_extensions import Annotated
log = logging.getLogger(__name__)
def is_list_of_primitives(field_type): def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types.""" """Check if a field type is a List of primitive types."""
@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
if discriminator_value in type_map: if discriminator_value in type_map:
chosen_type = type_map[discriminator_value] chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:") log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and ( if existing_value and (
getattr(existing_value, discriminator) != discriminator_value getattr(existing_value, discriminator) != discriminator_value
@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
setattr(sub_config, discriminator, discriminator_value) setattr(sub_config, discriminator, discriminator_value)
return sub_config return sub_config
else: else:
print(f"Invalid {discriminator}. Please try again.") log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way. # This is somewhat elaborate, but does not purport to be comprehensive in any way.
@ -180,7 +183,7 @@ def prompt_for_config(
config_data[field_name] = validated_value config_data[field_name] = validated_value
break break
except KeyError: except KeyError:
print( log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}" f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
) )
continue continue
@ -197,7 +200,7 @@ def prompt_for_config(
config_data[field_name] = None config_data[field_name] = None
continue continue
nested_type = get_non_none_type(field_type) nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:") log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value) config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union( elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type) get_non_none_type(field_type)
@ -213,7 +216,7 @@ def prompt_for_config(
existing_value, existing_value,
) )
elif can_recurse(field_type): elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:") log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config( config_data[field_name] = prompt_for_config(
field_type, field_type,
existing_value, existing_value,
@ -240,7 +243,7 @@ def prompt_for_config(
config_data[field_name] = None config_data[field_name] = None
break break
else: else:
print("This field is required. Please provide a value.") log.error("This field is required. Please provide a value.")
continue continue
else: else:
try: try:
@ -264,12 +267,12 @@ def prompt_for_config(
value = [element_type(item) for item in value] value = [element_type(item) for item in value]
except json.JSONDecodeError: except json.JSONDecodeError:
print( log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]' 'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
) )
continue continue
except ValueError as e: except ValueError as e:
print(f"{str(e)}") log.error(f"{str(e)}")
continue continue
elif get_origin(field_type) is dict: elif get_origin(field_type) is dict:
@ -281,7 +284,7 @@ def prompt_for_config(
) )
except json.JSONDecodeError: except json.JSONDecodeError:
print( log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict." "Invalid JSON. Please enter a valid JSON-encoded dict."
) )
continue continue
@ -298,7 +301,7 @@ def prompt_for_config(
value = field_type(user_input) value = field_type(user_input)
except ValueError: except ValueError:
print( log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}" f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
) )
continue continue
@ -311,6 +314,6 @@ def prompt_for_config(
config_data[field_name] = validated_value config_data[field_name] = validated_value
break break
except ValueError as e: except ValueError as e:
print(f"Validation error: {str(e)}") log.error(f"Validation error: {str(e)}")
return config_type(**config_data) return config_type(**config_data)


@ -6,6 +6,7 @@
import asyncio import asyncio
import copy import copy
import logging
import os import os
import re import re
import secrets import secrets
@ -19,7 +20,6 @@ from urllib.parse import urlparse
import httpx import httpx
from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
@ -43,6 +43,8 @@ from .tools.builtin import (
) )
from .tools.safety import SafeTool from .tools.safety import SafeTool
log = logging.getLogger(__name__)
def make_random_string(length: int = 8): def make_random_string(length: int = 8):
return "".join( return "".join(
@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason=StopReason.end_of_turn, stop_reason=StopReason.end_of_turn,
) )
) )
# print_dialog(messages)
return messages return messages
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream, stream=request.stream,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
cprint( log.info(
f"{chunk.role.capitalize()}: {chunk.content}", f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
) )
output_message = chunk output_message = chunk
continue continue
@ -407,7 +406,7 @@ class ChatAgent(ShieldRunnerMixin):
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}" msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else: else:
msg_str = str(msg) msg_str = str(msg)
cprint(f"{msg_str}", color=color) log.info(f"{msg_str}")
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -506,12 +505,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.") log.info("Done with MAX iterations, exiting.")
yield message yield message
break break
if stop_reason == StopReason.out_of_tokens: if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.") log.info("Out of token budget, exiting.")
yield message yield message
break break
@ -525,10 +524,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments message.content = [message.content] + attachments
yield message yield message
else: else:
cprint(f"Partial message: {str(message)}", color="green") log.info(f"Partial message: {str(message)}", color="green")
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
cprint(f"{str(message)}", color="green") log.info(f"{str(message)}", color="green")
try: try:
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
@ -740,9 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
for c in chunks[: memory.max_chunks]: for c in chunks[: memory.max_chunks]:
tokens += c.token_count tokens += c.token_count
if tokens > memory.max_tokens_in_context: if tokens > memory.max_tokens_in_context:
cprint( log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
) )
break break
picked.append(f"id:{c.document_id}; content:{c.content}") picked.append(f"id:{c.document_id}; content:{c.content}")
@ -786,7 +784,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path path = urlparse(uri).path
basename = os.path.basename(path) basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}" filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}") log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(uri)
@ -826,20 +824,3 @@ async def execute_tool_call_maybe(
tool = tools_dict[name] tool = tools_dict[name]
result_messages = await tool.run(messages) result_messages = await tool.run(messages)
return result_messages return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
+import logging
 import uuid
 from datetime import datetime
@@ -15,6 +15,8 @@ from pydantic import BaseModel
 from llama_stack.providers.utils.kvstore import KVStore
+log = logging.getLogger(__name__)
 class AgentSessionInfo(BaseModel):
 session_id: str
@@ -78,7 +80,7 @@ class AgentPersistence:
 turn = Turn(**json.loads(value))
 turns.append(turn)
 except Exception as e:
-print(f"Error parsing turn: {e}")
+log.error(f"Error parsing turn: {e}")
 continue
 turns.sort(key=lambda x: (x.completed_at or datetime.min))
 return turns


@@ -10,8 +10,6 @@ from jinja2 import Template
 from llama_models.llama3.api import * # noqa: F403
-from termcolor import cprint # noqa: F401
 from llama_stack.apis.agents import (
 DefaultMemoryQueryGeneratorConfig,
 LLMMemoryQueryGeneratorConfig,
@@ -36,7 +34,6 @@ async def generate_rag_query(
 query = await llm_rag_query_generator(config, messages, **kwargs)
 else:
 raise NotImplementedError(f"Unsupported memory query generator {config.type}")
-# cprint(f"Generated query >>>: {query}", color="green")
 return query


@@ -5,14 +5,16 @@
 # the root directory of this source tree.
 import asyncio
+import logging
 from typing import List
 from llama_models.llama3.api.datatypes import Message
-from termcolor import cprint
 from llama_stack.apis.safety import * # noqa: F403
+log = logging.getLogger(__name__)
 class SafetyException(Exception):  # noqa: N818
 def __init__(self, violation: SafetyViolation):
@@ -51,7 +53,4 @@ class ShieldRunnerMixin:
 if violation.violation_level == ViolationLevel.ERROR:
 raise SafetyException(violation)
 elif violation.violation_level == ViolationLevel.WARN:
-cprint(
-    f"[Warn]{identifier} raised a warning",
-    color="red",
-)
+log.warning(f"[Warn]{identifier} raised a warning")


@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
import re import re
import tempfile import tempfile
@ -12,7 +13,6 @@ from abc import abstractmethod
from typing import List, Optional from typing import List, Optional
import requests import requests
from termcolor import cprint
from .ipython_tool.code_execution import ( from .ipython_tool.code_execution import (
CodeExecutionContext, CodeExecutionContext,
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]: def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match: if match:
@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
if res_out != "": if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr": if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red") log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage( message = ToolResponseMessage(
call_id=tool_call.call_id, call_id=tool_call.call_id,


@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
import base64 import base64
import io import io
import json as _json import json as _json
import logging
import matplotlib import matplotlib
from matplotlib.backend_bases import FigureManagerBase from matplotlib.backend_bases import FigureManagerBase
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib # Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.backends.backend_agg import FigureCanvasAgg
log = logging.getLogger(__name__)
class CustomFigureCanvas(FigureCanvasAgg): class CustomFigureCanvas(FigureCanvasAgg):
def show(self): def show(self):
@ -80,7 +83,7 @@ def show():
) )
req_con.send_bytes(_json_dump.encode("utf-8")) req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp) log.info(resp)
FigureCanvas = CustomFigureCanvas FigureCanvas = CustomFigureCanvas


@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json import json
import logging
import math import math
import os import os
import sys import sys
@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
) )
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
@ -50,6 +50,8 @@ from .config import (
MetaReferenceQuantizedInferenceConfig, MetaReferenceQuantizedInferenceConfig,
) )
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str: def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor())) checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -185,7 +187,7 @@ class Llama:
model = Transformer(model_args) model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds") log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model) return Llama(model, tokenizer, model_args, llama_model)
def __init__( def __init__(
@ -221,7 +223,7 @@ class Llama:
self.formatter.vision_token if t == 128256 else t self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens for t in model_input.tokens
] ]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red") log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens] prompt_tokens = [model_input.tokens]
bsz = 1 bsz = 1
@ -231,9 +233,7 @@ class Llama:
max_prompt_len = max(len(t) for t in prompt_tokens) max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len: if max_prompt_len >= params.max_seq_len:
cprint( log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import asyncio
+import logging
 from typing import AsyncGenerator, List
@@ -25,6 +26,7 @@ from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
+log = logging.getLogger(__name__)
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
@@ -49,7 +51,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
 # verify that the checkpoint actually is for this model lol
 async def initialize(self) -> None:
-print(f"Loading model `{self.model.descriptor()}`")
+log.info(f"Loading model `{self.model.descriptor()}`")
 if self.config.create_distributed_process_group:
 self.generator = LlamaModelParallelGenerator(self.config)
 self.generator.start()


@ -11,6 +11,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
import multiprocessing import multiprocessing
import os import os
import tempfile import tempfile
@ -37,6 +38,8 @@ from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult from .generation import TokenResult
log = logging.getLogger(__name__)
class ProcessingMessageName(str, Enum): class ProcessingMessageName(str, Enum):
ready_request = "ready_request" ready_request = "ready_request"
@ -183,16 +186,16 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(), group=get_model_parallel_group(),
) )
if isinstance(updates[0], CancelSentinel): if isinstance(updates[0], CancelSentinel):
print("quitting generation loop because request was cancelled") log.info(
"quitting generation loop because request was cancelled"
)
break break
if mp_rank_0(): if mp_rank_0():
send_obj(EndSentinel()) send_obj(EndSentinel())
except Exception as e: except Exception as e:
print(f"[debug] got exception {e}") log.exception("exception in generation loop")
import traceback
traceback.print_exc()
if mp_rank_0(): if mp_rank_0():
send_obj(ExceptionResponse(error=str(e))) send_obj(ExceptionResponse(error=str(e)))
@ -252,7 +255,7 @@ def worker_process_entrypoint(
except StopIteration: except StopIteration:
break break
print("[debug] worker process done") log.info("[debug] worker process done")
def launch_dist_group( def launch_dist_group(
@ -313,7 +316,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest())) request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv() response = request_socket.recv()
print("Loaded model...") log.info("Loaded model...")
return request_socket, process return request_socket, process
@ -361,7 +364,7 @@ class ModelParallelProcessGroup:
break break
if isinstance(obj, ExceptionResponse): if isinstance(obj, ExceptionResponse):
print(f"[debug] got exception {obj.error}") log.error(f"[debug] got exception {obj.error}")
raise Exception(obj.error) raise Exception(obj.error)
if isinstance(obj, TaskResponse): if isinstance(obj, TaskResponse):


@@ -8,14 +8,20 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import collections
+import logging
 from typing import Optional, Type
+log = logging.getLogger(__name__)
 try:
 import fbgemm_gpu.experimental.gen_ai # noqa: F401
-print("Using efficient FP8 operators in FBGEMM.")
+log.info("Using efficient FP8 operators in FBGEMM.")
 except ImportError:
-print("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
+log.error(
+    "No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt."
+)
 raise
 import torch


@ -7,6 +7,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import logging
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -21,7 +22,6 @@ from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import nn, Tensor from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@ -30,6 +30,8 @@ from llama_stack.apis.inference import QuantizationType
from ..config import MetaReferenceQuantizedInferenceConfig from ..config import MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)
def swiglu_wrapper( def swiglu_wrapper(
self, self,
@ -60,7 +62,7 @@ def convert_to_fp8_quantized_model(
# Move weights to GPU with quantization # Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value: if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow") log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join( fp8_scales_path = os.path.join(
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt" checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
) )
@ -85,7 +87,7 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub, fp8_activation_scale_ub,
) )
else: else:
cprint("Quantizing fp8 weights from bf16...", "yellow") log.info("Quantizing fp8 weights from bf16...")
for block in model.layers: for block in model.layers:
if isinstance(block, TransformerBlock): if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):


@ -8,6 +8,7 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json import json
import logging
import os import os
import shutil import shutil
import sys import sys
@ -32,6 +33,8 @@ from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impl
quantize_fp8, quantize_fp8,
) )
log = logging.getLogger(__name__)
def main( def main(
ckpt_dir: str, ckpt_dir: str,
@ -102,7 +105,7 @@ def main(
else: else:
torch.set_default_tensor_type(torch.cuda.HalfTensor) torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(ckpt_path) log.info(ckpt_path)
assert ( assert (
quantized_ckpt_dir is not None quantized_ckpt_dir is not None
), "QUantized checkpoint directory should not be None" ), "QUantized checkpoint directory should not be None"


@@ -4,10 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
+class LogFormat(Enum):
+    TEXT = "text"
+    JSON = "json"
 @json_schema_type
-class ConsoleConfig(BaseModel): ...
+class ConsoleConfig(BaseModel):
+    log_format: LogFormat = LogFormat.JSON


@@ -4,8 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 from typing import Optional
+from .config import LogFormat
 from llama_stack.apis.telemetry import * # noqa: F403
 from .config import ConsoleConfig
@@ -38,7 +41,11 @@ class ConsoleTelemetryImpl(Telemetry):
 span_name = ".".join(names) if names else None
-formatted = format_event(event, span_name)
+if self.config.log_format == LogFormat.JSON:
+    formatted = format_event_json(event, span_name)
+else:
+    formatted = format_event_text(event, span_name)
 if formatted:
 print(formatted)
@@ -69,7 +76,7 @@ SEVERITY_COLORS = {
 }
-def format_event(event: Event, span_name: str) -> Optional[str]:
+def format_event_text(event: Event, span_name: str) -> Optional[str]:
 timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
 span = ""
 if span_name:
@@ -87,3 +94,23 @@ def format_event(event: Event, span_name: str) -> Optional[str]:
 return None
 return f"Unknown event type: {event}"
+def format_event_json(event: Event, span_name: str) -> Optional[str]:
+    base_data = {
+        "timestamp": event.timestamp.isoformat(),
+        "trace_id": event.trace_id,
+        "span_id": event.span_id,
+        "span_name": span_name,
+    }
+    if isinstance(event, UnstructuredLogEvent):
+        base_data.update(
+            {"type": "log", "severity": event.severity.name, "message": event.message}
+        )
+        return json.dumps(base_data)
+    elif isinstance(event, StructuredLogEvent):
+        return None
+    return json.dumps({"error": f"Unknown event type: {event}"})
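
For reference, a small hedged sketch (not part of this diff) of the shape of a JSON log line produced by the `format_event_json` path above and how a downstream consumer might parse it; the field names mirror the `base_data` dict in the diff, while the values shown are illustrative:

```python
import json

# Example line as emitted by the JSON console telemetry path (illustrative values).
line = (
    '{"timestamp": "2024-11-21T11:32:53.000000", "trace_id": "abc123", '
    '"span_id": "def456", "span_name": "sse_generator", '
    '"type": "log", "severity": "INFO", "message": "Resolved 5 providers"}'
)

event = json.loads(line)
if event.get("type") == "log":
    # Route on severity; keys follow format_event_json's base_data fields.
    print(f'[{event["severity"]}] {event["span_name"]}: {event["message"]}')
```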


@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any, Dict, List from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from termcolor import cprint
from .config import CodeScannerConfig from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [ ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner", "CodeScanner",
"CodeShield", "CodeShield",
@ -49,7 +49,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
from codeshield.cs import CodeShield from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages]) text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta") log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text) result = await CodeShield.scan_code(text)
violation = None violation = None


@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import PromptGuardConfig, PromptGuardType from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
PROMPT_GUARD_MODEL = "Prompt-Guard-86M" PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
@ -93,9 +94,8 @@ class PromptGuardShield:
probabilities = torch.softmax(logits / self.temperature, dim=-1) probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item() score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item() score_malicious = probabilities[0, 2].item()
cprint( log.info(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
) )
violation = None violation = None


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import AsyncGenerator from typing import AsyncGenerator
import httpx import httpx
@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media, request_has_media,
) )
log = logging.getLogger(__name__)
model_aliases = [ model_aliases = [
build_model_alias( build_model_alias(
@ -105,7 +107,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url) return AsyncClient(host=self.url)
async def initialize(self) -> None: async def initialize(self) -> None:
print(f"checking connectivity to Ollama at `{self.url}`...") log.info(f"checking connectivity to Ollama at `{self.url}`...")
try: try:
await self.client.ps() await self.client.ps()
except httpx.ConnectError as e: except httpx.ConnectError as e:


@@ -34,7 +34,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
-logger = logging.getLogger(__name__)
+log = logging.getLogger(__name__)
 class _HfAdapter(Inference, ModelsProtocolPrivate):
@@ -264,7 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
 async def initialize(self, config: TGIImplConfig) -> None:
-print(f"Initializing TGI client with url={config.url}")
+log.info(f"Initializing TGI client with url={config.url}")
 self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
 endpoint_info = await self.client.get_endpoint_info()
 self.max_tokens = endpoint_info["max_total_tokens"]


@ -3,6 +3,8 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import AsyncGenerator from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
def build_model_aliases(): def build_model_aliases():
return [ return [
build_model_alias( build_model_alias(
@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
print(f"Initializing VLLM client with base_url={self.config.url}") log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def shutdown(self) -> None: async def shutdown(self) -> None:


@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
from typing import List from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
) )
log = logging.getLogger(__name__)
class ChromaIndex(EmbeddingIndex): class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection): def __init__(self, client: chromadb.AsyncHttpClient, collection):
@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc) doc = json.loads(doc)
chunk = Chunk(**doc) chunk = Chunk(**doc)
except Exception: except Exception:
import traceback log.exception(f"Failed to parse document: {doc}")
traceback.print_exc()
print(f"Failed to parse document: {doc}")
continue continue
chunks.append(chunk) chunks.append(chunk)
@ -73,7 +73,7 @@ class ChromaIndex(EmbeddingIndex):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}") log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/") url = url.rstrip("/")
parsed = urlparse(url) parsed = urlparse(url)
@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
print(f"Connecting to Chroma server at: {self.host}:{self.port}") log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port) self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e: except Exception as e:
import traceback log.exception("Could not connect to Chroma server")
traceback.print_exc()
raise RuntimeError("Could not connect to Chroma server") from e raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
data = json.loads(collection.metadata["bank"]) data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(VectorMemoryBank, data) bank = parse_obj_as(VectorMemoryBank, data)
except Exception: except Exception:
import traceback log.exception(f"Failed to parse bank: {collection.metadata}")
traceback.print_exc()
print(f"Failed to parse bank: {collection.metadata}")
continue continue
index = BankWithIndex( index = BankWithIndex(


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import List, Tuple from typing import List, Tuple
import psycopg2 import psycopg2
@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorConfig from .config import PGVectorConfig
log = logging.getLogger(__name__)
def check_extension_version(cur): def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize(self) -> None:
print(f"Initializing PGVector memory adapter with config: {self.config}") log.info(f"Initializing PGVector memory adapter with config: {self.config}")
try: try:
self.conn = psycopg2.connect( self.conn = psycopg2.connect(
host=self.config.host, host=self.config.host,
@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
version = check_extension_version(self.cursor) version = check_extension_version(self.cursor)
if version: if version:
print(f"Vector extension version: {version}") log.info(f"Vector extension version: {version}")
else: else:
raise RuntimeError("Vector extension is not installed.") raise RuntimeError("Vector extension is not installed.")
@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
""" """
) )
except Exception as e: except Exception as e:
import traceback log.exception("Could not connect to PGVector database server")
traceback.print_exc()
raise RuntimeError("Could not connect to PGVector database server") from e raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
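A hedged sketch of the log-then-reraise pattern used here (the `ConnectionError` stand-in below is illustrative, not `psycopg2`): `log.exception` records the traceback for operators, while `raise ... from e` keeps the original error attached as `__cause__` for callers.

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

def connect(dsn: str):
    try:
        # Stand-in for psycopg2.connect(...) failing.
        raise ConnectionError(f"no server listening for {dsn}")
    except Exception as e:
        log.exception("Could not connect to PGVector database server")
        # Chain the original exception so it is not lost to callers.
        raise RuntimeError("Could not connect to PGVector database server") from e

try:
    connect("dbname=llamastack host=localhost port=5432")
except RuntimeError as err:
    log.info("caused by: %r", err.__cause__)
```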
View file
@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import traceback import logging
import uuid import uuid
from typing import Any, Dict, List from typing import Any, Dict, List
@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
) )
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id" CHUNK_ID_KEY = "_chunk_id"
@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex):
try: try:
chunk = Chunk(**point.payload["chunk_content"]) chunk = Chunk(**point.payload["chunk_content"])
except Exception: except Exception:
traceback.print_exc() log.exception("Failed to parse chunk")
continue continue
chunks.append(chunk) chunks.append(chunk)
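Each touched module creates its logger with `logging.getLogger(__name__)`. A small sketch of why that matters (logger names below are illustrative, not the real module paths): names form a dot-separated hierarchy, so an operator can quiet one provider subtree without editing code.

```python
import logging

logging.basicConfig(level=logging.INFO)

# Hypothetical names; the real modules use their own __name__.
logging.getLogger("llama_stack.providers.memory").setLevel(logging.WARNING)

qdrant_log = logging.getLogger("llama_stack.providers.memory.qdrant")
server_log = logging.getLogger("llama_stack.server")

qdrant_log.info("suppressed by the WARNING threshold on the parent logger")
qdrant_log.warning("still emitted")
server_log.info("unaffected, emitted at INFO")
```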
View file
@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import WeaviateConfig, WeaviateRequestProviderData from .config import WeaviateConfig, WeaviateRequestProviderData
log = logging.getLogger(__name__)
class WeaviateIndex(EmbeddingIndex): class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str): def __init__(self, client: weaviate.Client, collection_name: str):
@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json) chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict) chunk = Chunk(**chunk_dict)
except Exception: except Exception:
import traceback log.exception(f"Failed to parse document: {chunk_json}")
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
continue continue
chunks.append(chunk) chunks.append(chunk)
View file
@ -7,14 +7,13 @@
import base64 import base64
import io import io
import json import json
import logging
from typing import Tuple from typing import Tuple
import httpx import httpx
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from PIL import Image as PIL_Image from PIL import Image as PIL_Image
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import ModelFamily from llama_models.datatypes import ModelFamily
@ -29,6 +28,8 @@ from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
def content_has_media(content: InterleavedTextMedia): def content_has_media(content: InterleavedTextMedia):
def _has_media_content(c): def _has_media_content(c):
@ -175,13 +176,13 @@ def chat_completion_request_to_messages(
""" """
model = resolve_model(llama_model) model = resolve_model(llama_model)
if model is None: if model is None:
cprint(f"Could not resolve model {llama_model}", color="red") log.error(f"Could not resolve model {llama_model}")
return request.messages return request.messages
allowed_models = supported_inference_models() allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models] descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors: if model.descriptor() not in descriptors:
cprint(f"Unsupported inference model? {model.descriptor()}", color="red") log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages return request.messages
if model.model_family == ModelFamily.llama3_1 or ( if model.model_family == ModelFamily.llama3_1 or (
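With `cprint(..., color="red")` gone, severity is carried by the log level rather than a hard-coded terminal color. If colored human-readable output is wanted, it can be restored on the handler side; a rough sketch (the formatter below is illustrative, not part of this PR):

```python
import logging

class ColorFormatter(logging.Formatter):
    RED = "\033[31m"
    RESET = "\033[0m"

    def format(self, record: logging.LogRecord) -> str:
        msg = super().format(record)
        # Color only ERROR and above; leave INFO/DEBUG untouched.
        return f"{self.RED}{msg}{self.RESET}" if record.levelno >= logging.ERROR else msg

handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter("%(levelname)s %(name)s: %(message)s"))
logging.basicConfig(level=logging.INFO, handlers=[handler])

log = logging.getLogger(__name__)
log.error("Could not resolve model Llama3.1-8B-Instruct")  # rendered in red
log.info("Resolved model")                                  # plain
```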
View file
@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
@ -13,6 +14,8 @@ from psycopg2.extras import DictCursor
from ..api import KVStore from ..api import KVStore
from ..config import PostgresKVStoreConfig from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__)
class PostgresKVStoreImpl(KVStore): class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig): def __init__(self, config: PostgresKVStoreConfig):
@ -43,9 +46,8 @@ class PostgresKVStoreImpl(KVStore):
""" """
) )
except Exception as e: except Exception as e:
import traceback
traceback.print_exc() log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _namespaced_key(self, key: str) -> str: def _namespaced_key(self, key: str) -> str:
View file
@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import io import io
import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@ -16,13 +17,14 @@ import httpx
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from pypdf import PdfReader from pypdf import PdfReader
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
log = logging.getLogger(__name__)
ALL_MINILM_L6_V2_DIMENSION = 384 ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODELS = {} EMBEDDING_MODELS = {}
@ -35,7 +37,7 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
if loaded_model is not None: if loaded_model is not None:
return loaded_model return loaded_model
print(f"Loading sentence transformer for {model}...") log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model) loaded_model = SentenceTransformer(model)
@ -92,7 +94,7 @@ def content_from_data(data_url: str) -> str:
return "\n".join([page.extract_text() for page in pdf_reader.pages]) return "\n".join([page.extract_text() for page in pdf_reader.pages])
else: else:
cprint("Could not extract content from data_url properly.", color="red") log.error("Could not extract content from data_url properly.")
return "" return ""
View file
@ -17,6 +17,8 @@ from typing import Any, Callable, Dict, List
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
log = logging.getLogger(__name__)
def generate_short_uuid(len: int = 12): def generate_short_uuid(len: int = 12):
full_uuid = uuid.uuid4() full_uuid = uuid.uuid4()
@ -40,7 +42,7 @@ class BackgroundLogger:
try: try:
self.log_queue.put_nowait(event) self.log_queue.put_nowait(event)
except queue.Full: except queue.Full:
print("Log queue is full, dropping event") log.error("Log queue is full, dropping event")
def _process_logs(self): def _process_logs(self):
while True: while True:
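These module-level loggers only emit what the process has been configured to handle; with no handler configured anywhere, Python's last-resort handler prints WARNING and above to stderr and drops INFO records such as the trace-initialization message below. A minimal sketch of the kind of one-time setup an entry point would do (logger name and format string are illustrative):

```python
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

log = logging.getLogger("llama_stack.providers.utils.telemetry.tracing")  # name illustrative
log.info("No Telemetry implementation set. Skipping trace initialization...")
```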
@ -125,7 +127,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
print("No Telemetry implementation set. Skipping trace initialization...") log.info("No Telemetry implementation set. Skipping trace initialization...")
return return
trace_id = generate_short_uuid() trace_id = generate_short_uuid()