diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index 03574559ca..8d9994cb6f 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -1,4 +1,4 @@ -import os, types, traceback, copy +import os, types, traceback, copy, asyncio import json from enum import Enum import time @@ -82,6 +82,27 @@ class GeminiConfig: } +class TextStreamer: + """ + A class designed to return an async stream from an AsyncGenerateContentResponse object. + """ + + def __init__(self, response): + self.response = response + self._aiter = self.response.__aiter__() + + async def __aiter__(self): + while True: + try: + # This will manually advance the async iterator. + # In the case the next object doesn't exist, __anext__() will simply raise a StopAsyncIteration exception + next_object = await self._aiter.__anext__() + yield next_object + except StopAsyncIteration: + # After getting all items from the async iterator, stop iterating + break + + def completion( model: str, messages: list, @@ -160,12 +181,11 @@ def completion( ) response = litellm.CustomStreamWrapper( - aiter(response), + TextStreamer(response), model, custom_llm_provider="gemini", logging_obj=logging_obj, ) - return response return async_streaming() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index ee6a187e28..58dc25fb05 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream(): {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", - "content": "how does a court case get to the Supreme Court?", + "content": "What do you know?", }, ] print("testing gemini streaming") @@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream(): print(f"chunk in acompletion gemini: {chunk}") print(chunk.choices[0].delta) chunk, finished = streaming_format_tests(idx, chunk) - if idx > 5: - break if finished: break print(f"chunk: {chunk}")