fix(gemini.py): implement custom streamer

Krrish Dholakia 2024-02-20 17:10:51 -08:00
parent 7b641491a2
commit 1d3bef2e9c
2 changed files with 24 additions and 6 deletions

gemini.py
@@ -1,4 +1,4 @@
-import os, types, traceback, copy
+import os, types, traceback, copy, asyncio
 import json
 from enum import Enum
 import time
@@ -82,6 +82,27 @@ class GeminiConfig:
         }
+
+
+class TextStreamer:
+    """
+    A class designed to return an async stream from an AsyncGenerateContentResponse object.
+    """
+
+    def __init__(self, response):
+        self.response = response
+        self._aiter = self.response.__aiter__()
+
+    async def __aiter__(self):
+        while True:
+            try:
+                # Manually advance the async iterator.
+                # If there is no next object, __anext__() will simply raise a StopAsyncIteration exception.
+                next_object = await self._aiter.__anext__()
+                yield next_object
+            except StopAsyncIteration:
+                # After getting all items from the async iterator, stop iterating.
+                break
+
+
 def completion(
     model: str,
     messages: list,
@@ -160,12 +181,11 @@ def completion(
                 )
                 response = litellm.CustomStreamWrapper(
-                    aiter(response),
+                    TextStreamer(response),
                     model,
                     custom_llm_provider="gemini",
                     logging_obj=logging_obj,
                 )
                 return response
 
             return async_streaming()
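For reference, a minimal self-contained sketch of the behaviour the new TextStreamer relies on: it re-exposes any async-iterable response as an async generator, which CustomStreamWrapper can then consume with async for. The fake_gemini_response generator below is a hypothetical stand-in for google.generativeai's AsyncGenerateContentResponse and is not part of this commit.

import asyncio


class TextStreamer:
    """Re-expose an async-iterable response as an async generator (as added above)."""

    def __init__(self, response):
        self.response = response
        self._aiter = self.response.__aiter__()

    async def __aiter__(self):
        while True:
            try:
                # Advance the underlying iterator until it raises StopAsyncIteration.
                yield await self._aiter.__anext__()
            except StopAsyncIteration:
                break


async def fake_gemini_response():
    # Hypothetical stand-in for the real Gemini response object,
    # which is likewise consumed with: async for chunk in response.
    for text in ["Hello", ", ", "world"]:
        yield text


async def main():
    async for chunk in TextStreamer(fake_gemini_response()):
        print(chunk)


asyncio.run(main())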

@@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream():
         {"role": "system", "content": "You are a helpful assistant."},
         {
             "role": "user",
-            "content": "how does a court case get to the Supreme Court?",
+            "content": "What do you know?",
         },
     ]
     print("testing gemini streaming")
@@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream():
             print(f"chunk in acompletion gemini: {chunk}")
             print(chunk.choices[0].delta)
             chunk, finished = streaming_format_tests(idx, chunk)
-            if idx > 5:
-                break
             if finished:
                 break
             print(f"chunk: {chunk}")