fix(gemini.py): implement custom streamer

Author: Krrish Dholakia
Date:   2024-02-20 17:10:51 -08:00
parent  7b641491a2
commit  1d3bef2e9c
2 changed files with 24 additions and 6 deletions

gemini.py

@@ -1,4 +1,4 @@
-import os, types, traceback, copy
+import os, types, traceback, copy, asyncio
 import json
 from enum import Enum
 import time
@@ -82,6 +82,27 @@ class GeminiConfig:
         }
+class TextStreamer:
+    """
+    A class designed to return an async stream from an AsyncGenerateContentResponse object.
+    """
+    def __init__(self, response):
+        self.response = response
+        self._aiter = self.response.__aiter__()
+    async def __aiter__(self):
+        while True:
+            try:
+                # This will manually advance the async iterator.
+                # If there is no next object, __anext__() simply raises a StopAsyncIteration exception.
+                next_object = await self._aiter.__anext__()
+                yield next_object
+            except StopAsyncIteration:
+                # After getting all items from the async iterator, stop iterating.
+                break
 def completion(
     model: str,
     messages: list,
@@ -160,12 +181,11 @@
             )
             response = litellm.CustomStreamWrapper(
-                aiter(response),
+                TextStreamer(response),
                 model,
                 custom_llm_provider="gemini",
                 logging_obj=logging_obj,
             )
             return response
         return async_streaming()
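The replaced line relied on the `aiter()` builtin, which is only available on Python 3.10+; wrapping the SDK response in the new `TextStreamer` presumably avoids that requirement while still handing `litellm.CustomStreamWrapper` a plain async iterable. Below is a minimal, self-contained sketch (not part of the commit) of how the wrapper behaves: `fake_gemini_chunks` is a hypothetical stand-in for google-generativeai's `AsyncGenerateContentResponse`, and the class body mirrors the one added above.

import asyncio


class TextStreamer:
    """Mirror of the wrapper added in this commit: re-yield items from an
    async-iterable response until it is exhausted."""

    def __init__(self, response):
        self.response = response
        self._aiter = self.response.__aiter__()

    async def __aiter__(self):
        while True:
            try:
                yield await self._aiter.__anext__()
            except StopAsyncIteration:
                break


async def fake_gemini_chunks():
    # Hypothetical stand-in for the SDK's async streaming response object.
    for text in ["Hello", " from", " Gemini"]:
        yield text


async def main():
    # CustomStreamWrapper only needs something it can `async for` over;
    # TextStreamer provides exactly that.
    async for chunk in TextStreamer(fake_gemini_chunks()):
        print(chunk)


asyncio.run(main())

Because every `__aiter__()` call delegates to the single underlying iterator created in `__init__`, the stream is only drained once, which matches how `CustomStreamWrapper` consumes it.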

streaming tests

@@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream():
         {"role": "system", "content": "You are a helpful assistant."},
         {
             "role": "user",
-            "content": "how does a court case get to the Supreme Court?",
+            "content": "What do you know?",
         },
     ]
     print("testing gemini streaming")
@@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream():
             print(f"chunk in acompletion gemini: {chunk}")
             print(chunk.choices[0].delta)
             chunk, finished = streaming_format_tests(idx, chunk)
-            if idx > 5:
-                break
             if finished:
                 break
         print(f"chunk: {chunk}")
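Removing the `if idx > 5: break` guard means the test now drains the entire stream instead of only the first few chunks, exercising the new `TextStreamer` path end to end. For reference, a hedged sketch of the same consumption pattern outside the test harness; the model alias and API-key setup are assumptions, not part of the diff.

import asyncio

import litellm


async def demo():
    # Assumes GEMINI_API_KEY is set in the environment and that the
    # "gemini/gemini-pro" alias is available in this litellm version.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "What do you know?"}],
        stream=True,
    )
    # The returned object is the CustomStreamWrapper built in gemini.py above,
    # so it can be consumed with a plain `async for`, just like the test does.
    async for chunk in response:
        print(chunk.choices[0].delta)


asyncio.run(demo())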