Fix non-streaming API in inference server

Hardik Shah 2024-07-30 14:25:50 -07:00
parent 404af06e02
commit cc98fbb058
3 changed files with 32 additions and 26 deletions

View file

@@ -14,6 +14,7 @@ from termcolor import cprint
 from .api import (
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -50,35 +51,33 @@ class InferenceClient(Inference):
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
-                        yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        if request.stream:
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        else:
+                            yield ChatCompletionResponse(**json.loads(data))
                     except Exception as e:
                         print(data)
                         print(f"Error with parsing or validation: {e}")


-async def run_main(host: str, port: int):
+async def run_main(host: str, port: int, stream: bool):
     client = InferenceClient(f"http://{host}:{port}")
     message = UserMessage(content="hello world, help me out here")
     cprint(f"User>{message.content}", "green")
-    req = ChatCompletionRequest(
-        model=InstructModel.llama3_70b_chat,
-        messages=[message],
-        stream=True,
-    )
     iterator = client.chat_completion(
         ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
             messages=[message],
-            stream=True,
+            stream=stream,
         )
     )
     async for log in EventLogger().log(iterator):
         log.print()


-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))


 if __name__ == "__main__":
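For reference, a minimal usage sketch of the updated client in non-streaming mode (not part of the diff). The module path llama_toolchain.inference.client and the localhost:5000 endpoint are assumptions; the request and response types are the ones imported in the hunk above.

# Hypothetical usage sketch, not part of this commit. Assumes the client class
# shown above is importable as llama_toolchain.inference.client.InferenceClient
# and that an inference server is listening on localhost:5000.
import asyncio

from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    InstructModel,
    UserMessage,
)
from llama_toolchain.inference.client import InferenceClient


async def non_streaming_demo() -> None:
    client = InferenceClient("http://localhost:5000")
    request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[UserMessage(content="hello world, help me out here")],
        stream=False,
    )
    # With stream=False the generator yields a single ChatCompletionResponse
    # instead of a sequence of ChatCompletionResponseStreamChunk events.
    async for response in client.chat_completion(request):
        print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(non_streaming_demo())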

View file

@@ -6,7 +6,10 @@
 from termcolor import cprint

-from llama_toolchain.inference.api import ChatCompletionResponseEventType
+from llama_toolchain.inference.api import (
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk
+)


 class LogEvent:
@@ -25,12 +28,16 @@ class LogEvent:

 class EventLogger:
-    async def log(self, event_generator, stream=True):
+    async def log(self, event_generator):
         async for chunk in event_generator:
-            event = chunk.event
-            if event.event_type == ChatCompletionResponseEventType.start:
+            if isinstance(chunk, ChatCompletionResponseStreamChunk):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.start:
+                    yield LogEvent("Assistant> ", color="cyan", end="")
+                elif event.event_type == ChatCompletionResponseEventType.progress:
+                    yield LogEvent(event.delta, color="yellow", end="")
+                elif event.event_type == ChatCompletionResponseEventType.complete:
+                    yield LogEvent("")
+            else:
                 yield LogEvent("Assistant> ", color="cyan", end="")
-            elif event.event_type == ChatCompletionResponseEventType.progress:
-                yield LogEvent(event.delta, color="yellow", end="")
-            elif event.event_type == ChatCompletionResponseEventType.complete:
-                yield LogEvent("")
+                yield LogEvent(chunk.completion_message.content, color="yellow")
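An illustrative sketch of the new non-streaming branch in EventLogger (not part of the diff). The module path llama_toolchain.inference.event_logger is an assumption, and a SimpleNamespace stands in for a real ChatCompletionResponse: it fails the isinstance() check against ChatCompletionResponseStreamChunk, so the logger prints the whole completion at once.

# Hypothetical demo, not part of this commit. Assumes the logger is importable
# as llama_toolchain.inference.event_logger.EventLogger.
import asyncio
from types import SimpleNamespace

from llama_toolchain.inference.event_logger import EventLogger


async def fake_non_streaming_response():
    # Stand-in for a ChatCompletionResponse carrying a completion_message;
    # it is not a ChatCompletionResponseStreamChunk, so the logger takes the
    # non-streaming branch added above.
    yield SimpleNamespace(
        completion_message=SimpleNamespace(content="Hello! How can I help?")
    )


async def demo() -> None:
    async for log in EventLogger().log(fake_non_streaming_response()):
        log.print()


if __name__ == "__main__":
    asyncio.run(demo())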

View file

@@ -16,6 +16,7 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
+    ChatCompletionResponse,
     ChatCompletionRequest,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
@@ -40,12 +41,13 @@ class InferenceImpl(Inference):
         raise NotImplementedError()

     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
+        if request.stream:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
             )
-        )

         tokens = []
         logprobs = []
@@ -152,8 +154,6 @@ class InferenceImpl(Inference):
             # TODO(ashwin): what else do we need to send out here when everything finishes?
         else:
             yield ChatCompletionResponse(
-                content=message.content,
-                tool_calls=message.tool_calls,
-                stop_reason=stop_reason,
+                completion_message=message,
                 logprobs=logprobs if request.logprobs else None,
             )
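To make the server-side change easier to follow, here is a hypothetical condensation of the flow after this commit. The real token-generation loop is replaced by a run_model callable, and the import path and the shape of the final "complete" chunk are assumptions; only the stream/non-stream branching mirrors the hunks above.

# Hypothetical sketch, not the actual implementation in this commit.
from typing import AsyncGenerator

from llama_toolchain.inference.api import (  # import path assumed
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


async def chat_completion_sketch(request: ChatCompletionRequest, run_model) -> AsyncGenerator:
    if request.stream:
        # Streaming clients still get the start/progress/complete event protocol.
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta="",
            )
        )

    # run_model stands in for the real generation loop; progress chunks elided.
    message, logprobs = await run_model(request)

    if request.stream:
        # Field layout of the complete chunk is an assumption.
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta="",
            )
        )
    else:
        # Non-streaming clients get a single response carrying the full message.
        yield ChatCompletionResponse(
            completion_message=message,
            logprobs=logprobs if request.logprobs else None,
        )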