use logging instead of prints (#499)

# What does this PR do?

This PR moves all print statements to use logging. Things changed:
- Had to add `await start_trace("sse_generator")` to server.py to
actually get tracing working; otherwise no logs were showing up.
- If no telemetry provider is configured in the run.yaml, logs are written to
stdout.
- By default, the logs are emitted as JSON, but we expose an option
to switch to human-readable output.
This commit is contained in:
Dinesh Yeduguru 2024-11-21 11:32:53 -08:00 committed by GitHub
parent 4e1105e563
commit 6395dadc2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 234 additions and 163 deletions

View file

@ -6,6 +6,7 @@
import asyncio
import copy
import logging
import os
import re
import secrets
@ -19,7 +20,6 @@ from urllib.parse import urlparse
import httpx
from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@ -43,6 +43,8 @@ from .tools.builtin import (
)
from .tools.safety import SafeTool
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
@ -137,7 +139,6 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
async def create_session(self, name: str) -> str:
@ -185,10 +186,8 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
cprint(
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
"white",
attrs=["bold"],
)
output_message = chunk
continue
@ -407,7 +406,7 @@ class ChatAgent(ShieldRunnerMixin):
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
cprint(f"{msg_str}", color=color)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -506,12 +505,12 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
log.info("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
cprint("Out of token budget, exiting.")
log.info("Out of token budget, exiting.")
yield message
break
@ -525,10 +524,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments
yield message
else:
cprint(f"Partial message: {str(message)}", color="green")
# NOTE: no `color=` here. Logger.info() forwards unknown keywords to
# Logger._log(), which rejects them with TypeError once INFO is enabled.
log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message]
else:
cprint(f"{str(message)}", color="green")
# NOTE: no `color=` here. Logger.info() forwards unknown keywords to
# Logger._log(), which rejects them with TypeError once INFO is enabled.
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
@ -740,9 +739,8 @@ class ChatAgent(ShieldRunnerMixin):
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
@ -786,7 +784,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -826,20 +824,3 @@ async def execute_tool_call_maybe(
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
def print_dialog(messages: List[Message]):
for i, m in enumerate(messages):
if m.role == Role.user.value:
color = "red"
elif m.role == Role.assistant.value:
color = "white"
elif m.role == Role.ipython.value:
color = "yellow"
elif m.role == Role.system.value:
color = "green"
else:
color = "white"
s = str(m)
cprint(f"{i} ::: {s[:100]}...", color=color)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import datetime
@ -15,6 +15,8 @@ from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
@ -78,7 +80,7 @@ class AgentPersistence:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns

View file

@ -10,8 +10,6 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
@ -36,7 +34,6 @@ async def generate_rag_query(
query = await llm_rag_query_generator(config, messages, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
# cprint(f"Generated query >>>: {query}", color="green")
return query

View file

@ -5,14 +5,16 @@
# the root directory of this source tree.
import asyncio
import logging
from typing import List
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
class SafetyException(Exception): # noqa: N818
def __init__(self, violation: SafetyViolation):
@ -51,7 +53,4 @@ class ShieldRunnerMixin:
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{identifier} raised a warning",
color="red",
)
log.warning(f"[Warn]{identifier} raised a warning")

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import logging
import re
import tempfile
@ -12,7 +13,6 @@ from abc import abstractmethod
from typing import List, Optional
import requests
from termcolor import cprint
from .ipython_tool.code_execution import (
CodeExecutionContext,
@ -27,6 +27,9 @@ from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
@ -383,7 +386,7 @@ class CodeInterpreterTool(BaseTool):
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
cprint(f"ipython tool error: ↓\n{res_out}", color="red")
log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage(
call_id=tool_call.call_id,

View file

@ -11,6 +11,7 @@ A custom Matplotlib backend that overrides the show method to return image bytes
import base64
import io
import json as _json
import logging
import matplotlib
from matplotlib.backend_bases import FigureManagerBase
@ -18,6 +19,8 @@ from matplotlib.backend_bases import FigureManagerBase
# Import necessary components from Matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg
log = logging.getLogger(__name__)
class CustomFigureCanvas(FigureCanvasAgg):
def show(self):
@ -80,7 +83,7 @@ def show():
)
req_con.send_bytes(_json_dump.encode("utf-8"))
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
print(resp)
log.info(resp)
FigureCanvas = CustomFigureCanvas