mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(gemini.py): implement custom streamer
This commit is contained in:
parent 7b641491a2
commit 1d3bef2e9c
2 changed files with 24 additions and 6 deletions
@@ -1,4 +1,4 @@
-import os, types, traceback, copy
+import os, types, traceback, copy, asyncio
 import json
 from enum import Enum
 import time
@@ -82,6 +82,27 @@ class GeminiConfig:
         }


+class TextStreamer:
+    """
+    A class designed to return an async stream from an AsyncGenerateContentResponse object.
+    """
+
+    def __init__(self, response):
+        self.response = response
+        self._aiter = self.response.__aiter__()
+
+    async def __aiter__(self):
+        while True:
+            try:
+                # Manually advance the async iterator.
+                # If the next object doesn't exist, __anext__() simply raises StopAsyncIteration.
+                next_object = await self._aiter.__anext__()
+                yield next_object
+            except StopAsyncIteration:
+                # After getting all items from the async iterator, stop iterating.
+                break
+
+
 def completion(
     model: str,
     messages: list,
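The added class re-exposes an async-iterable response object as a plain async generator, presumably so that litellm.CustomStreamWrapper can drive it chunk by chunk. A minimal, runnable sketch of the same manual-advance pattern outside litellm (fake_gemini_response is a hypothetical stand-in for the SDK's AsyncGenerateContentResponse, which is async-iterable but not a plain async generator):

import asyncio


class TextStreamer:
    """Re-expose any async-iterable object as a plain async generator."""

    def __init__(self, response):
        self.response = response
        self._aiter = self.response.__aiter__()

    async def __aiter__(self):
        while True:
            try:
                # Advance the wrapped iterator by hand; StopAsyncIteration
                # signals that the stream is exhausted.
                yield await self._aiter.__anext__()
            except StopAsyncIteration:
                break


async def fake_gemini_response():
    # Hypothetical stand-in for the Gemini SDK's streaming response object.
    for part in ["Hello", ", ", "world"]:
        yield part


async def main():
    async for chunk in TextStreamer(fake_gemini_response()):
        print(chunk)


asyncio.run(main())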
@@ -160,12 +181,11 @@ def completion(
             )

             response = litellm.CustomStreamWrapper(
-                aiter(response),
+                TextStreamer(response),
                 model,
                 custom_llm_provider="gemini",
                 logging_obj=logging_obj,
             )

             return response

         return async_streaming()
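A likely motivation for this swap (an inference, not stated in the commit): the builtin aiter() only exists on Python 3.10+, so replacing it with an explicit wrapper keeps the streaming path working on older interpreters. For async-iterable objects the two calls are otherwise equivalent:

# Builtin, Python >= 3.10 only:
stream = aiter(response)

# Explicit equivalent that works on any Python 3 with async support:
stream = response.__aiter__()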
@@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream():
         {"role": "system", "content": "You are a helpful assistant."},
         {
             "role": "user",
-            "content": "how does a court case get to the Supreme Court?",
+            "content": "What do you know?",
         },
     ]
     print("testing gemini streaming")
@@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream():
             print(f"chunk in acompletion gemini: {chunk}")
             print(chunk.choices[0].delta)
             chunk, finished = streaming_format_tests(idx, chunk)
-            if idx > 5:
-                break
             if finished:
                 break
             print(f"chunk: {chunk}")
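With the idx > 5 early exit removed, the test now consumes the stream to completion instead of cutting it off after six chunks. A hedged sketch of how this code path is exercised end to end (the model name is an assumption; the message comes from the updated test):

import asyncio
import litellm


async def demo():
    # stream=True routes through the CustomStreamWrapper(TextStreamer(...)) path above.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",  # assumed provider-prefixed model name
        messages=[{"role": "user", "content": "What do you know?"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta)


asyncio.run(demo())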