use logging instead of prints (#499)

# What does this PR do?

This PR replaces all print statements with logging. Things changed:
- Had to add `await start_trace("sse_generator")` to server.py to
actually get tracing working; otherwise no logs were showing up.
- If no telemetry provider is configured in run.yaml, we write to
stdout.
- By default, logs are emitted as JSON, but we expose an option to
output them in a human-readable format (see the sketch below).
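
A minimal sketch of what the JSON/human-readable toggle could look like, using only the standard library. `setup_logging` and the `human_readable` flag are illustrative names, not the PR's actual configuration surface:

```python
import json
import logging
import sys


class JSONFormatter(logging.Formatter):
    """Render each record as one JSON object per line."""

    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {
                "level": record.levelname,
                "name": record.name,
                "message": record.getMessage(),
            }
        )


def setup_logging(human_readable: bool = False) -> None:
    # Hypothetical toggle; the PR exposes a similar option, but the
    # exact config key in run.yaml may differ.
    handler = logging.StreamHandler(sys.stdout)  # stdout fallback
    if human_readable:
        handler.setFormatter(
            logging.Formatter("%(levelname)s %(name)s: %(message)s")
        )
    else:
        handler.setFormatter(JSONFormatter())
    logging.basicConfig(level=logging.INFO, handlers=[handler])
```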
Committed by Dinesh Yeduguru on 2024-11-21 11:32:53 -08:00 (via GitHub)
Commit 6395dadc2b (parent 4e1105e563)
36 changed files with 234 additions and 163 deletions


```diff
@@ -8,6 +8,7 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import json
+import logging
 import math
 import os
 import sys
```
```diff
@@ -31,7 +32,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
-from termcolor import cprint

 from llama_stack.apis.inference import *  # noqa: F403
```
```diff
@@ -50,6 +50,8 @@ from .config import (
     MetaReferenceQuantizedInferenceConfig,
 )

+log = logging.getLogger(__name__)
+

 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
```
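
The hunk above introduces the standard module-level logger pattern that the rest of the diff relies on. A self-contained sketch (the function and messages are hypothetical, just to mirror the call sites being rewritten):

```python
import logging

# One logger per module, named after the module's dotted import path,
# so handlers and levels can be configured per subsystem.
log = logging.getLogger(__name__)


def load_checkpoint() -> None:  # hypothetical call site
    log.info("Loaded checkpoint")     # replaces a bare print(...)
    log.error("Out of token budget")  # replaces a red cprint(...)
```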
```diff
@@ -185,7 +187,7 @@ class Llama:
         model = Transformer(model_args)
         model.load_state_dict(state_dict, strict=False)
-        print(f"Loaded in {time.time() - start_time:.2f} seconds")
+        log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
         return Llama(model, tokenizer, model_args, llama_model)

     def __init__(
```
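
A side note on the replacement line: the diff keeps an f-string inside the logging call, which is evaluated eagerly. The stdlib also supports %-style lazy formatting, where the message is only built if the record passes the level check. A sketch of the alternative (values illustrative):

```python
import logging
import time

log = logging.getLogger(__name__)
start_time = time.time()

# f-string: formatted eagerly, even if INFO is disabled.
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")

# %-style: formatting deferred until the record is actually emitted.
log.info("Loaded in %.2f seconds", time.time() - start_time)
```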
```diff
@@ -221,7 +223,7 @@ class Llama:
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
             ]
-            cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+            log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
         prompt_tokens = [model_input.tokens]

         bsz = 1
```
```diff
@@ -231,9 +233,7 @@
         max_prompt_len = max(len(t) for t in prompt_tokens)

         if max_prompt_len >= params.max_seq_len:
-            cprint(
-                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
-            )
+            log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
             return

         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
```
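
One behavioral difference from cprint worth noting: without any logging configuration, Python's last-resort handler only emits records at WARNING and above, so the `log.error` in the hunk above still reaches stderr while the `log.info` calls stay silent. A quick check:

```python
import logging

log = logging.getLogger("llama_stack.demo")

# No handlers configured: logging.lastResort (level WARNING) applies.
log.info("Input to model -> ...")              # silent by default
log.error("Out of token budget 4096 vs 2048")  # printed to stderr

# Opting in to INFO output (basicConfig defaults to stderr):
logging.basicConfig(level=logging.INFO)
log.info("now visible")
```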