mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(gemini.py): implement custom streamer
This commit is contained in:
parent
7b641491a2
commit
1d3bef2e9c
2 changed files with 24 additions and 6 deletions
|
@ -1,4 +1,4 @@
|
|||
import os, types, traceback, copy
|
||||
import os, types, traceback, copy, asyncio
|
||||
import json
|
||||
from enum import Enum
|
||||
import time
|
||||
|
@ -82,6 +82,27 @@ class GeminiConfig:
|
|||
}
|
||||
|
||||
|
||||
class TextStreamer:
    """
    Adapter that exposes an AsyncGenerateContentResponse object as an async stream.
    """

    def __init__(self, response):
        self.response = response
        # Create the underlying async iterator once, so every iteration of this
        # wrapper shares the same position in the response stream.
        self._aiter = self.response.__aiter__()

    async def __aiter__(self):
        """Yield each item from the wrapped response until it is exhausted."""
        while True:
            try:
                # Manually advance the shared async iterator; when no item
                # remains, __anext__() raises StopAsyncIteration.
                item = await self._aiter.__anext__()
            except StopAsyncIteration:
                # Stream exhausted — end iteration.
                return
            yield item
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -160,12 +181,11 @@ def completion(
|
|||
)
|
||||
|
||||
response = litellm.CustomStreamWrapper(
|
||||
aiter(response),
|
||||
TextStreamer(response),
|
||||
model,
|
||||
custom_llm_provider="gemini",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
return async_streaming()
|
||||
|
|
|
@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream():
|
|||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how does a court case get to the Supreme Court?",
|
||||
"content": "What do you know?",
|
||||
},
|
||||
]
|
||||
print("testing gemini streaming")
|
||||
|
@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream():
|
|||
print(f"chunk in acompletion gemini: {chunk}")
|
||||
print(chunk.choices[0].delta)
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if idx > 5:
|
||||
break
|
||||
if finished:
|
||||
break
|
||||
print(f"chunk: {chunk}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue