From cc98fbb0585f02d2818b2ea0291fdea3ab29ae5f Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Tue, 30 Jul 2024 14:25:50 -0700
Subject: [PATCH] fix non-streaming api in inference server

---
 llama_toolchain/inference/client.py       | 19 +++++++++----------
 llama_toolchain/inference/event_logger.py | 23 +++++++++++++++--------
 llama_toolchain/inference/inference.py    | 16 ++++++++--------
 3 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 3523e1867..3dd646457 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -14,6 +14,7 @@ from termcolor import cprint
 
 from .api import (
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -50,35 +51,33 @@ class InferenceClient(Inference):
                     if line.startswith("data:"):
                         data = line[len("data: ") :]
                         try:
-                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                            if request.stream:
+                                yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                            else:
+                                yield ChatCompletionResponse(**json.loads(data))
                         except Exception as e:
                             print(data)
                             print(f"Error with parsing or validation: {e}")
 
 
-async def run_main(host: str, port: int):
+async def run_main(host: str, port: int, stream: bool):
     client = InferenceClient(f"http://{host}:{port}")
 
     message = UserMessage(content="hello world, help me out here")
     cprint(f"User>{message.content}", "green")
-    req = ChatCompletionRequest(
-        model=InstructModel.llama3_70b_chat,
-        messages=[message],
-        stream=True,
-    )
     iterator = client.chat_completion(
         ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
             messages=[message],
-            stream=True,
+            stream=stream,
         )
     )
     async for log in EventLogger().log(iterator):
         log.print()
 
 
-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
 
 
 if __name__ == "__main__":
diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py
index 4e29c3614..9d9434b6a 100644
--- a/llama_toolchain/inference/event_logger.py
+++ b/llama_toolchain/inference/event_logger.py
@@ -6,7 +6,10 @@
 
 from termcolor import cprint
 
-from llama_toolchain.inference.api import ChatCompletionResponseEventType
+from llama_toolchain.inference.api import (
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk
+)
 
 
 class LogEvent:
@@ -25,12 +28,16 @@ class LogEvent:
 
 
 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(self, event_generator):
         async for chunk in event_generator:
-            event = chunk.event
-            if event.event_type == ChatCompletionResponseEventType.start:
+            if isinstance(chunk, ChatCompletionResponseStreamChunk):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.start:
+                    yield LogEvent("Assistant> ", color="cyan", end="")
+                elif event.event_type == ChatCompletionResponseEventType.progress:
+                    yield LogEvent(event.delta, color="yellow", end="")
+                elif event.event_type == ChatCompletionResponseEventType.complete:
+                    yield LogEvent("")
+            else:
                 yield LogEvent("Assistant> ", color="cyan", end="")
-            elif event.event_type == ChatCompletionResponseEventType.progress:
-                yield LogEvent(event.delta, color="yellow", end="")
-            elif event.event_type == ChatCompletionResponseEventType.complete:
-                yield LogEvent("")
+                yield LogEvent(chunk.completion_message.content, color="yellow")
diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py
index b49736208..b3fa058fe 100644
--- a/llama_toolchain/inference/inference.py
+++ b/llama_toolchain/inference/inference.py
@@ -16,6 +16,7 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
+    ChatCompletionResponse,
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
@@ -40,12 +41,13 @@ class InferenceImpl(Inference):
         raise NotImplementedError()
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
+        if request.stream:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
             )
-        )
 
         tokens = []
         logprobs = []
@@ -152,8 +154,6 @@ class InferenceImpl(Inference):
             # TODO(ashwin): what else do we need to send out here when everything finishes?
         else:
             yield ChatCompletionResponse(
-                content=message.content,
-                tool_calls=message.tool_calls,
-                stop_reason=stop_reason,
+                completion_message=message,
                 logprobs=logprobs if request.logprobs else None,
             )