fix(vertex_httpx.py): support async streaming for google ai studio gemini

This commit is contained in:
Krrish Dholakia 2024-06-17 14:59:30 -07:00
parent be66800a98
commit e92570534c

View file

@ -12,6 +12,7 @@ from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import ijson
import requests # type: ignore
import litellm
@ -257,7 +258,7 @@ async def make_call(
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(chunk_size=2056)
streaming_response=response.aiter_bytes(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
@ -288,7 +289,7 @@ def make_sync_call(
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056)
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
)
# LOGGING
@ -705,6 +706,25 @@ class VertexLLM(BaseLLM):
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=json.dumps(data), # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
### ASYNC COMPLETION
return self.async_completion(
model=model,
@ -916,9 +936,13 @@ class VertexLLM(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response):
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
self.response_iterator = iter(self.streaming_response)
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@ -970,10 +994,21 @@ class ModelResponseIterator:
def __next__(self):
try:
chunk = next(self.response_iterator)
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
@ -987,9 +1022,20 @@ class ModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
chunk = chunk.decode()
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
self.coro.send(chunk)
if self.events:
event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e: