fix non-streaming api in inference server

This commit is contained in:
Hardik Shah 2024-07-30 14:25:50 -07:00
parent 404af06e02
commit cc98fbb058
3 changed files with 32 additions and 26 deletions

View file

@@ -16,6 +16,7 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
@@ -40,12 +41,13 @@ class InferenceImpl(Inference):
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
)
tokens = []
logprobs = []
@@ -152,8 +154,6 @@ class InferenceImpl(Inference):
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
content=message.content,
tool_calls=message.tool_calls,
stop_reason=stop_reason,
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)