fix(vertex_httpx.py): support async streaming for google ai studio gemini

This commit is contained in:
Krrish Dholakia 2024-06-17 14:59:30 -07:00
parent be66800a98
commit e92570534c

View file

@@ -12,6 +12,7 @@ from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 import httpx  # type: ignore
+import ijson
 import requests  # type: ignore
 import litellm
@@ -257,7 +258,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(chunk_size=2056)
+        streaming_response=response.aiter_bytes(), sync_stream=False
     )

     # LOGGING
     logging_obj.post_call(
@@ -288,7 +289,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056)
+        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
     )

     # LOGGING
@@ -705,6 +706,25 @@ class VertexLLM(BaseLLM):
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=json.dumps(data),  # type: ignore
+                    api_base=url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,  # type: ignore
+                )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
@@ -916,9 +936,13 @@ class VertexLLM(BaseLLM):
 class ModelResponseIterator:
-    def __init__(self, streaming_response):
+    def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.response_iterator = iter(self.streaming_response)
+        if sync_stream:
+            self.response_iterator = iter(self.streaming_response)
+        self.events = ijson.sendable_list()
+        self.coro = ijson.items_coro(self.events, "item")

     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -970,10 +994,21 @@ class ModelResponseIterator:
     def __next__(self):
         try:
-            chunk = next(self.response_iterator)
-            chunk = chunk.decode()
-            json_chunk = json.loads(chunk)
-            return self.chunk_parser(chunk=json_chunk)
+            chunk = self.response_iterator.__next__()
+            self.coro.send(chunk)
+            if self.events:
+                event = self.events[0]
+                json_chunk = event
+                self.events.clear()
+                return self.chunk_parser(chunk=json_chunk)
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
         except StopIteration:
             raise StopIteration
         except ValueError as e:
@ -987,9 +1022,20 @@ class ModelResponseIterator:
async def __anext__(self): async def __anext__(self):
try: try:
chunk = await self.async_response_iterator.__anext__() chunk = await self.async_response_iterator.__anext__()
chunk = chunk.decode() self.coro.send(chunk)
json_chunk = json.loads(chunk) if self.events:
return self.chunk_parser(chunk=json_chunk) event = self.events[0]
json_chunk = event
self.events.clear()
return self.chunk_parser(chunk=json_chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration: except StopAsyncIteration:
raise StopAsyncIteration raise StopAsyncIteration
except ValueError as e: except ValueError as e: