do not swallow first token

This commit is contained in:
Hardik Shah 2025-02-27 19:17:44 -08:00
parent a9f5c5bfca
commit f55d812d8e

View file

@ -6,7 +6,7 @@
import json import json
import logging import logging
import warnings import warnings
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream from openai import AsyncStream
from openai.types.chat import ( from openai.types.chat import (
@ -843,12 +843,11 @@ async def convert_openai_chat_completion_stream(
""" """
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]: # def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start # yield ChatCompletionResponseEventType.start
while True: # while True:
yield ChatCompletionResponseEventType.progress # yield ChatCompletionResponseEventType.progress
event_type = ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
stop_reason = None stop_reason = None
toolcall_buffer = {} toolcall_buffer = {}
@ -868,7 +867,7 @@ async def convert_openai_chat_completion_stream(
if choice.delta.content: if choice.delta.content:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=TextDelta(text=choice.delta.content), delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs), logprobs=_convert_openai_logprobs(logprobs),
) )
@ -909,7 +908,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
tool_call=delta, tool_call=delta,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
@ -920,7 +919,7 @@ async def convert_openai_chat_completion_stream(
else: else:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""), delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs), logprobs=_convert_openai_logprobs(logprobs),
) )
@ -931,7 +930,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
tool_call=delta, tool_call=delta,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,