From 999195fe5b6416c092cbf8c5dabc2d221dad33f1 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Thu, 27 Feb 2025 20:53:47 -0800
Subject: [PATCH] fix: [Litellm] Do not swallow first token (#1316)

`ChatCompletionResponseEventType: start` is ignored and not yielded in
agent_instance, since we expect it to carry no content.

However, litellm sends the first event as `ChatCompletionResponseEventType:
start` with content (the very first token), which we were therefore skipping.

```
LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/agents/test_agents.py --inference-model "openai/gpt-4o-mini" -k test_agent_simple
```

This test was failing before the fix, since the word "hello" was missing from
the final response.
---
 .../utils/inference/openai_compat.py | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index eaf5ad2e1..d0fdf6385 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -6,7 +6,7 @@
 import json
 import logging
 import warnings
-from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
+from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
 
 from openai import AsyncStream
 from openai.types.chat import (
@@ -841,14 +841,13 @@
     Convert a stream of OpenAI chat completion chunks into a stream of
     ChatCompletionResponseStreamChunk.
     """
-
-    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
-    def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
-        yield ChatCompletionResponseEventType.start
-        while True:
-            yield ChatCompletionResponseEventType.progress
-
-    event_type = _event_type_generator()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
 
     stop_reason = None
     toolcall_buffer = {}
@@ -868,7 +867,7 @@
             if choice.delta.content:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=event_type,
                         delta=TextDelta(text=choice.delta.content),
                         logprobs=_convert_openai_logprobs(logprobs),
                     )
@@ -909,7 +908,7 @@
                     toolcall_buffer["content"] += delta
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
-                            event_type=next(event_type),
+                            event_type=event_type,
                             delta=ToolCallDelta(
                                 tool_call=delta,
                                 parse_status=ToolCallParseStatus.in_progress,
@@ -920,7 +919,7 @@
             else:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=event_type,
                         delta=TextDelta(text=choice.delta.content or ""),
                         logprobs=_convert_openai_logprobs(logprobs),
                     )
@@ -931,7 +930,7 @@
             toolcall_buffer["content"] += delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=next(event_type),
+                    event_type=event_type,
                     delta=ToolCallDelta(
                         tool_call=delta,
                         parse_status=ToolCallParseStatus.in_progress,
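
For illustration, here is a minimal, self-contained sketch of the failure mode
and the fix. The `EventType`, `Chunk`, `old_convert`, `new_convert`, and
`agent_consume` names below are simplified stand-ins for the real llama_stack
types and agent loop, not the actual API:

```python
# Sketch only: simplified stand-ins for llama_stack's event types and agent loop.
from dataclasses import dataclass
from enum import Enum
from typing import Generator, Iterator, List


class EventType(Enum):
    start = "start"
    progress = "progress"


@dataclass
class Chunk:
    event_type: EventType
    text: str


def old_convert(tokens: Iterator[str]) -> Iterator[Chunk]:
    # Old behavior: a generator yields `start` once, then `progress` forever,
    # so the first *content* chunk is labeled `start`.
    def _event_type_generator() -> Generator[EventType, None, None]:
        yield EventType.start
        while True:
            yield EventType.progress

    event_type = _event_type_generator()
    for token in tokens:
        yield Chunk(event_type=next(event_type), text=token)


def new_convert(tokens: Iterator[str]) -> Iterator[Chunk]:
    # New behavior: emit an explicit empty `start` chunk up front, then label
    # every content chunk `progress`.
    yield Chunk(event_type=EventType.start, text="")
    for token in tokens:
        yield Chunk(event_type=EventType.progress, text=token)


def agent_consume(chunks: Iterator[Chunk]) -> str:
    # Mirrors the agent loop's assumption: `start` events carry no content,
    # so anything attached to them is dropped.
    return "".join(c.text for c in chunks if c.event_type is EventType.progress)


tokens: List[str] = ["Hello", ", ", "world"]
print(agent_consume(old_convert(iter(tokens))))  # ", world"  -- first token swallowed
print(agent_consume(new_convert(iter(tokens))))  # "Hello, world"
```

Emitting an explicit empty `start` chunk keeps the event protocol intact for
consumers that key off it, while guaranteeing that every content token arrives
in a `progress` event the agent loop actually reads.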