diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 317637efa..2e8a36161 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -13,6 +13,7 @@ from .api import (
     Inference,
     UserMessage,
 )
+from .event_logger import EventLogger
 
 
 class InferenceClient(Inference):
@@ -56,14 +57,15 @@ async def run_main(host: str, port: int):
         messages=[message],
         stream=True,
     )
-    async for event in client.chat_completion(
+    iterator = client.chat_completion(
         ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )
-    ):
-        print(event)
+    )
+    async for log in EventLogger().log(iterator):
+        log.print()
 
 
 def main(host: str, port: int):
diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py
new file mode 100644
index 000000000..71d472ee1
--- /dev/null
+++ b/llama_toolchain/inference/event_logger.py
@@ -0,0 +1,33 @@
+
+from termcolor import cprint
+from llama_toolchain.inference.api import (
+    ChatCompletionResponseEventType,
+)
+
+
+class LogEvent:
+    def __init__(
+        self,
+        content: str = "",
+        end: str = "\n",
+        color="white",
+    ):
+        self.content = content
+        self.color = color
+        self.end = "\n" if end is None else end
+
+    def print(self, flush=True):
+        cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
+
+
+class EventLogger:
+    async def log(self, event_generator, stream=True):
+        async for chunk in event_generator:
+            event = chunk.event
+            if event.event_type == ChatCompletionResponseEventType.start:
+                yield LogEvent("Assistant> ", color="cyan", end="")
+            elif event.event_type == ChatCompletionResponseEventType.progress:
+                yield LogEvent(event.delta, color="yellow", end="")
+            elif event.event_type == ChatCompletionResponseEventType.complete:
+                yield LogEvent("")
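
Usage sketch: the snippet below shows how the new streaming path is meant to be driven end to end, mirroring `run_main` in the diff above. The `InferenceClient(base_url)` constructor, the host/port values, and the exact import locations of `ChatCompletionRequest`, `InstructModel`, and `UserMessage` are assumptions, since they are not shown in this diff.

```python
# Sketch only: drives the new EventLogger against a running inference server.
# Assumptions (not part of this diff): InferenceClient takes a base URL, and
# ChatCompletionRequest / InstructModel / UserMessage are importable from
# llama_toolchain.inference.api.
import asyncio

from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    InstructModel,
    UserMessage,
)
from llama_toolchain.inference.client import InferenceClient
from llama_toolchain.inference.event_logger import EventLogger


async def demo(host: str, port: int) -> None:
    client = InferenceClient(f"http://{host}:{port}")  # assumed constructor
    request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[UserMessage(content="hello world")],
        stream=True,
    )
    # chat_completion(...) yields response chunks; EventLogger().log(...) maps
    # them to printable LogEvents: a cyan "Assistant> " prefix on start, yellow
    # token deltas on progress, and a trailing newline on complete.
    iterator = client.chat_completion(request)
    async for log in EventLogger().log(iterator):
        log.print()


if __name__ == "__main__":
    asyncio.run(demo("localhost", 5000))  # hypothetical host/port
```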