diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index eaf5ad2e1..d0fdf6385 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -6,7 +6,7 @@
 import json
 import logging
 import warnings
-from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
+from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
 
 from openai import AsyncStream
 from openai.types.chat import (
@@ -841,14 +841,13 @@
     Convert a stream of OpenAI chat completion chunks into
     a stream of ChatCompletionResponseStreamChunk.
     """
-
-    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
-    def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
-        yield ChatCompletionResponseEventType.start
-        while True:
-            yield ChatCompletionResponseEventType.progress
-
-    event_type = _event_type_generator()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
 
     stop_reason = None
     toolcall_buffer = {}
@@ -868,7 +867,7 @@
             if choice.delta.content:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=event_type,
                         delta=TextDelta(text=choice.delta.content),
                         logprobs=_convert_openai_logprobs(logprobs),
                     )
@@ -909,7 +908,7 @@
                 toolcall_buffer["content"] += delta
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=event_type,
                         delta=ToolCallDelta(
                             tool_call=delta,
                             parse_status=ToolCallParseStatus.in_progress,
@@ -920,7 +919,7 @@
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=next(event_type),
+                    event_type=event_type,
                     delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=_convert_openai_logprobs(logprobs),
                 )
@@ -931,7 +930,7 @@
                 toolcall_buffer["content"] += delta
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=event_type,
                         delta=ToolCallDelta(
                             tool_call=delta,
                             parse_status=ToolCallParseStatus.in_progress,
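
The behavioral difference this patch makes: the old `_event_type_generator` produced `start` lazily, so `start` was only attached to whatever delta happened to arrive first (and never emitted at all for an empty stream); the new code yields an explicit `start` chunk with an empty `TextDelta` up front, then tags every subsequent event as `progress`. A minimal, self-contained sketch of the two patterns, using hypothetical stand-in types (`EventType`, `Event`, `empty_stream`) rather than the actual llama_stack classes:

```python
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import AsyncGenerator


class EventType(Enum):
    start = "start"
    progress = "progress"


@dataclass
class Event:
    event_type: EventType
    text: str


async def empty_stream() -> AsyncGenerator[str, None]:
    # An upstream stream that produces no content deltas.
    return
    yield  # unreachable; makes this function an async generator


async def convert_old(stream) -> AsyncGenerator[Event, None]:
    # Old pattern: `start` is produced lazily by a sync generator,
    # so it only appears when the first delta arrives.
    def _event_type_gen():
        yield EventType.start
        while True:
            yield EventType.progress

    event_type = _event_type_gen()
    async for delta in stream:
        yield Event(next(event_type), delta)


async def convert_new(stream) -> AsyncGenerator[Event, None]:
    # New pattern: emit `start` eagerly with an empty delta,
    # then mark every content delta as `progress`.
    yield Event(EventType.start, "")
    async for delta in stream:
        yield Event(EventType.progress, delta)


async def main():
    old = [e.event_type.value async for e in convert_old(empty_stream())]
    new = [e.event_type.value async for e in convert_new(empty_stream())]
    print("old:", old)  # old: [] -- no start event on an empty stream
    print("new:", new)  # new: ['start'] -- start is always emitted


asyncio.run(main())
```

A side effect of emitting `start` eagerly is that consumers can rely on exactly one `start` event per response, even when the upstream stream yields no content, and the `next(event_type)` calls scattered through the tool-call and text branches collapse into a single constant.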