add EventLogger for inference

Ashwin Bharambe 2024-07-22 15:11:34 -07:00
parent 7574ffb25f
commit bbfd8a587e
2 changed files with 38 additions and 3 deletions

llama_toolchain/inference/client.py

@@ -13,6 +13,7 @@ from .api import (
     Inference,
     UserMessage,
 )
+from .event_logger import EventLogger
 
 
 class InferenceClient(Inference):
@@ -56,14 +57,15 @@ async def run_main(host: str, port: int):
         messages=[message],
         stream=True,
     )
-    async for event in client.chat_completion(
+    iterator = client.chat_completion(
         ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
             messages=[message],
             stream=True,
         )
-    ):
-        print(event)
+    )
+    async for log in EventLogger().log(iterator):
+        log.print()
 
 
 def main(host: str, port: int):
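For context, the hunk above swaps direct iteration over raw chunks for a two-step pattern: capture the async iterator returned by chat_completion, then hand it to a second async generator that yields printable events. A minimal self-contained sketch of that pattern, with illustrative stand-in names rather than the real client API:

    import asyncio

    async def source():
        # Stand-in for client.chat_completion(...): an async iterator of text deltas.
        for delta in ["Hello", ", ", "world"]:
            yield delta

    async def wrap(iterator):
        # Stand-in for EventLogger().log(iterator): decorates the stream,
        # yielding a prefix once and then passing each delta through.
        yield "Assistant> "
        async for delta in iterator:
            yield delta

    async def main():
        iterator = source()
        async for piece in wrap(iterator):
            print(piece, end="", flush=True)
        print()

    asyncio.run(main())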

llama_toolchain/inference/event_logger.py (new file)

@@ -0,0 +1,33 @@
+from termcolor import cprint
+
+from llama_toolchain.inference.api import (
+    ChatCompletionResponseEventType,
+)
+
+
+class LogEvent:
+    def __init__(
+        self,
+        content: str = "",
+        end: str = "\n",
+        color="white",
+    ):
+        self.content = content
+        self.color = color
+        self.end = "\n" if end is None else end
+
+    def print(self, flush=True):
+        cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
+
+
+class EventLogger:
+    async def log(self, event_generator, stream=True):
+        async for chunk in event_generator:
+            event = chunk.event
+            if event.event_type == ChatCompletionResponseEventType.start:
+                yield LogEvent("Assistant> ", color="cyan", end="")
+            elif event.event_type == ChatCompletionResponseEventType.progress:
+                yield LogEvent(event.delta, color="yellow", end="")
+            elif event.event_type == ChatCompletionResponseEventType.complete:
+                yield LogEvent("")
+