Separate chat_completion stream and non-stream implementations

This is a pretty important requirement. The streaming response type is
an AsyncGenerator while the non-stream one is a single object. So far
this has worked _sometimes_ due to various pre-existing hacks (and in
some cases, just failed.)
This commit is contained in:
Ashwin Bharambe 2024-10-08 10:52:16 -07:00 committed by Ashwin Bharambe
parent f8752ab8dc
commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions

View file

@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
async def shutdown(self) -> None:
self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
@ -87,138 +87,163 @@ class MetaReferenceInferenceImpl(Inference):
async with SEMAPHORE:
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
messages = augment_messages_for_tools(request)
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={token_result.text: token_result.logprobs[0]}
)
)
tokens = []
logprobs = []
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
stop_reason = None
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
buffer = ""
ipython = False
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
messages = augment_messages_for_tools(request)
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
buffer = buffer[len("<|python_tag|>") :]
continue
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
tokens.append(token_result.token)
if not request.stream:
if request.logprobs:
assert (
len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token"
assert (
request.logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
else:
delta = text
)
continue
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
if request.logprobs:
assert len(token_result.logprobs) == 1
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)